diff options
Diffstat (limited to 'python')
33 files changed, 1303 insertions, 619 deletions
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index 873af306..e1f2e3b7 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -171,13 +171,6 @@ class _NestedDescriptorBase(DescriptorBase): self._serialized_start = serialized_start self._serialized_end = serialized_end - def GetTopLevelContainingType(self): - """Returns the root if this is a nested type, or itself if its the root.""" - desc = self - while desc.containing_type is not None: - desc = desc.containing_type - return desc - def CopyToProto(self, proto): """Copies this to the matching proto in descriptor_pb2. @@ -497,7 +490,7 @@ class FieldDescriptor(DescriptorBase): def __new__(cls, name, full_name, index, number, type, cpp_type, label, default_value, message_type, enum_type, containing_type, is_extension, extension_scope, options=None, - has_default_value=True, containing_oneof=None): + has_default_value=True, containing_oneof=None, json_name=None): _message.Message._CheckCalledFromGeneratedFile() if is_extension: return _message.default_pool.FindExtensionByName(full_name) @@ -507,7 +500,7 @@ class FieldDescriptor(DescriptorBase): def __init__(self, name, full_name, index, number, type, cpp_type, label, default_value, message_type, enum_type, containing_type, is_extension, extension_scope, options=None, - has_default_value=True, containing_oneof=None): + has_default_value=True, containing_oneof=None, json_name=None): """The arguments are as described in the description of FieldDescriptor attributes above. @@ -519,6 +512,10 @@ class FieldDescriptor(DescriptorBase): self.name = name self.full_name = full_name self._camelcase_name = None + if json_name is None: + self.json_name = _ToJsonName(name) + else: + self.json_name = json_name self.index = index self.number = number self.type = type @@ -894,6 +891,31 @@ def _ToCamelCase(name): return ''.join(result) +def _OptionsOrNone(descriptor_proto): + """Returns the value of the field `options`, or None if it is not set.""" + if descriptor_proto.HasField('options'): + return descriptor_proto.options + else: + return None + + +def _ToJsonName(name): + """Converts name to Json name and returns it.""" + capitalize_next = False + result = [] + + for c in name: + if c == '_': + capitalize_next = True + elif capitalize_next: + result.append(c.upper()) + capitalize_next = False + else: + result += c + + return ''.join(result) + + def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, syntax=None): """Make a protobuf Descriptor given a DescriptorProto protobuf. @@ -970,6 +992,10 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, full_name = '.'.join(full_message_name + [field_proto.name]) enum_desc = None nested_desc = None + if field_proto.json_name: + json_name = field_proto.json_name + else: + json_name = None if field_proto.HasField('type_name'): type_name = field_proto.type_name full_type_name = '.'.join(full_message_name + @@ -984,10 +1010,11 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, field_proto.number, field_proto.type, FieldDescriptor.ProtoTypeToCppProtoType(field_proto.type), field_proto.label, None, nested_desc, enum_desc, None, False, None, - options=field_proto.options, has_default_value=False) + options=_OptionsOrNone(field_proto), has_default_value=False, + json_name=json_name) fields.append(field) desc_name = '.'.join(full_message_name) return Descriptor(desc_proto.name, desc_name, None, None, fields, list(nested_types.values()), list(enum_types.values()), [], - options=desc_proto.options) + options=_OptionsOrNone(desc_proto)) diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 5c055ab9..5f43ee5f 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -62,7 +62,7 @@ from google.protobuf import descriptor_database from google.protobuf import text_encoding -_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS +_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS # pylint: disable=protected-access def _NormalizeFullyQualifiedName(name): @@ -80,6 +80,14 @@ def _NormalizeFullyQualifiedName(name): return name.lstrip('.') +def _OptionsOrNone(descriptor_proto): + """Returns the value of the field `options`, or None if it is not set.""" + if descriptor_proto.HasField('options'): + return descriptor_proto.options + else: + return None + + class DescriptorPool(object): """A collection of protobufs dynamically constructed by descriptor protos.""" @@ -326,78 +334,61 @@ class DescriptorPool(object): name=file_proto.name, package=file_proto.package, syntax=file_proto.syntax, - options=file_proto.options, + options=_OptionsOrNone(file_proto), serialized_pb=file_proto.SerializeToString(), dependencies=direct_deps, public_dependencies=public_deps) - if _USE_C_DESCRIPTORS: - # When using C++ descriptors, all objects defined in the file were added - # to the C++ database when the FileDescriptor was built above. - # Just add them to this descriptor pool. - def _AddMessageDescriptor(message_desc): - self._descriptors[message_desc.full_name] = message_desc - for nested in message_desc.nested_types: - _AddMessageDescriptor(nested) - for enum_type in message_desc.enum_types: - _AddEnumDescriptor(enum_type) - def _AddEnumDescriptor(enum_desc): - self._enum_descriptors[enum_desc.full_name] = enum_desc - for message_type in file_descriptor.message_types_by_name.values(): - _AddMessageDescriptor(message_type) - for enum_type in file_descriptor.enum_types_by_name.values(): - _AddEnumDescriptor(enum_type) + scope = {} + + # This loop extracts all the message and enum types from all the + # dependencies of the file_proto. This is necessary to create the + # scope of available message types when defining the passed in + # file proto. + for dependency in built_deps: + scope.update(self._ExtractSymbols( + dependency.message_types_by_name.values())) + scope.update((_PrefixWithDot(enum.full_name), enum) + for enum in dependency.enum_types_by_name.values()) + + for message_type in file_proto.message_type: + message_desc = self._ConvertMessageDescriptor( + message_type, file_proto.package, file_descriptor, scope, + file_proto.syntax) + file_descriptor.message_types_by_name[message_desc.name] = ( + message_desc) + + for enum_type in file_proto.enum_type: + file_descriptor.enum_types_by_name[enum_type.name] = ( + self._ConvertEnumDescriptor(enum_type, file_proto.package, + file_descriptor, None, scope)) + + for index, extension_proto in enumerate(file_proto.extension): + extension_desc = self._MakeFieldDescriptor( + extension_proto, file_proto.package, index, is_extension=True) + extension_desc.containing_type = self._GetTypeFromScope( + file_descriptor.package, extension_proto.extendee, scope) + self._SetFieldType(extension_proto, extension_desc, + file_descriptor.package, scope) + file_descriptor.extensions_by_name[extension_desc.name] = ( + extension_desc) + + for desc_proto in file_proto.message_type: + self._SetAllFieldTypes(file_proto.package, desc_proto, scope) + + if file_proto.package: + desc_proto_prefix = _PrefixWithDot(file_proto.package) else: - scope = {} - - # This loop extracts all the message and enum types from all the - # dependencies of the file_proto. This is necessary to create the - # scope of available message types when defining the passed in - # file proto. - for dependency in built_deps: - scope.update(self._ExtractSymbols( - dependency.message_types_by_name.values())) - scope.update((_PrefixWithDot(enum.full_name), enum) - for enum in dependency.enum_types_by_name.values()) - - for message_type in file_proto.message_type: - message_desc = self._ConvertMessageDescriptor( - message_type, file_proto.package, file_descriptor, scope, - file_proto.syntax) - file_descriptor.message_types_by_name[message_desc.name] = ( - message_desc) - - for enum_type in file_proto.enum_type: - file_descriptor.enum_types_by_name[enum_type.name] = ( - self._ConvertEnumDescriptor(enum_type, file_proto.package, - file_descriptor, None, scope)) - - for index, extension_proto in enumerate(file_proto.extension): - extension_desc = self._MakeFieldDescriptor( - extension_proto, file_proto.package, index, is_extension=True) - extension_desc.containing_type = self._GetTypeFromScope( - file_descriptor.package, extension_proto.extendee, scope) - self._SetFieldType(extension_proto, extension_desc, - file_descriptor.package, scope) - file_descriptor.extensions_by_name[extension_desc.name] = ( - extension_desc) - - for desc_proto in file_proto.message_type: - self._SetAllFieldTypes(file_proto.package, desc_proto, scope) - - if file_proto.package: - desc_proto_prefix = _PrefixWithDot(file_proto.package) - else: - desc_proto_prefix = '' + desc_proto_prefix = '' - for desc_proto in file_proto.message_type: - desc = self._GetTypeFromScope( - desc_proto_prefix, desc_proto.name, scope) - file_descriptor.message_types_by_name[desc_proto.name] = desc + for desc_proto in file_proto.message_type: + desc = self._GetTypeFromScope( + desc_proto_prefix, desc_proto.name, scope) + file_descriptor.message_types_by_name[desc_proto.name] = desc - for index, service_proto in enumerate(file_proto.service): - file_descriptor.services_by_name[service_proto.name] = ( - self._MakeServiceDescriptor(service_proto, index, scope, - file_proto.package, file_descriptor)) + for index, service_proto in enumerate(file_proto.service): + file_descriptor.services_by_name[service_proto.name] = ( + self._MakeServiceDescriptor(service_proto, index, scope, + file_proto.package, file_descriptor)) self.Add(file_proto) self._file_descriptors[file_proto.name] = file_descriptor @@ -413,6 +404,7 @@ class DescriptorPool(object): package: The package the proto should be located in. file_desc: The file containing this message. scope: Dict mapping short and full symbols to message and enum types. + syntax: string indicating syntax of the file ("proto2" or "proto3") Returns: The added descriptor. @@ -463,7 +455,7 @@ class DescriptorPool(object): nested_types=nested, enum_types=enums, extensions=extensions, - options=desc_proto.options, + options=_OptionsOrNone(desc_proto), is_extendable=is_extendable, extension_ranges=extension_ranges, file=file_desc, @@ -517,7 +509,7 @@ class DescriptorPool(object): file=file_desc, values=values, containing_type=containing_type, - options=enum_proto.options) + options=_OptionsOrNone(enum_proto)) scope['.%s' % enum_name] = desc self._enum_descriptors[enum_name] = desc return desc @@ -562,7 +554,7 @@ class DescriptorPool(object): default_value=None, is_extension=is_extension, extension_scope=None, - options=field_proto.options) + options=_OptionsOrNone(field_proto)) def _SetAllFieldTypes(self, package, desc_proto, scope): """Sets all the descriptor's fields's types. @@ -681,7 +673,7 @@ class DescriptorPool(object): name=value_proto.name, index=index, number=value_proto.number, - options=value_proto.options, + options=_OptionsOrNone(value_proto), type=None) def _MakeServiceDescriptor(self, service_proto, service_index, scope, @@ -711,7 +703,7 @@ class DescriptorPool(object): full_name=service_name, index=service_index, methods=methods, - options=service_proto.options, + options=_OptionsOrNone(service_proto), file=file_desc) return desc @@ -740,7 +732,7 @@ class DescriptorPool(object): containing_service=None, input_type=input_type, output_type=output_type, - options=method_proto.options) + options=_OptionsOrNone(method_proto)) def _ExtractSymbols(self, descriptors): """Pulls out all the symbols from descriptor protos. diff --git a/python/google/protobuf/internal/any_test.proto b/python/google/protobuf/internal/any_test.proto index cd641ca0..76a7ebd6 100644 --- a/python/google/protobuf/internal/any_test.proto +++ b/python/google/protobuf/internal/any_test.proto @@ -30,13 +30,21 @@ // Author: jieluo@google.com (Jie Luo) -syntax = "proto3"; +syntax = "proto2"; package google.protobuf.internal; import "google/protobuf/any.proto"; message TestAny { - google.protobuf.Any value = 1; - int32 int_value = 2; + optional google.protobuf.Any value = 1; + optional int32 int_value = 2; + extensions 10 to max; +} + +message TestAnyExtension1 { + extend TestAny { + optional TestAnyExtension1 extension1 = 98418603; + } + optional int32 i = 15; } diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index ce46d08c..de13018e 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -436,9 +436,11 @@ class ScalarMap(MutableMapping): """Simple, type-checked, dict-like container for holding repeated scalars.""" # Disallows assignment to other attributes. - __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener'] + __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener', + '_entry_descriptor'] - def __init__(self, message_listener, key_checker, value_checker): + def __init__(self, message_listener, key_checker, value_checker, + entry_descriptor): """ Args: message_listener: A MessageListener implementation. @@ -448,10 +450,12 @@ class ScalarMap(MutableMapping): inserted into this container. value_checker: A type_checkers.ValueChecker instance to run on values inserted into this container. + entry_descriptor: The MessageDescriptor of a map entry: key and value. """ self._message_listener = message_listener self._key_checker = key_checker self._value_checker = value_checker + self._entry_descriptor = entry_descriptor self._values = {} def __getitem__(self, key): @@ -513,6 +517,9 @@ class ScalarMap(MutableMapping): self._values.clear() self._message_listener.Modified() + def GetEntryClass(self): + return self._entry_descriptor._concrete_class + class MessageMap(MutableMapping): @@ -520,9 +527,10 @@ class MessageMap(MutableMapping): # Disallows assignment to other attributes. __slots__ = ['_key_checker', '_values', '_message_listener', - '_message_descriptor'] + '_message_descriptor', '_entry_descriptor'] - def __init__(self, message_listener, message_descriptor, key_checker): + def __init__(self, message_listener, message_descriptor, key_checker, + entry_descriptor): """ Args: message_listener: A MessageListener implementation. @@ -532,10 +540,12 @@ class MessageMap(MutableMapping): inserted into this container. value_checker: A type_checkers.ValueChecker instance to run on values inserted into this container. + entry_descriptor: The MessageDescriptor of a map entry: key and value. """ self._message_listener = message_listener self._message_descriptor = message_descriptor self._key_checker = key_checker + self._entry_descriptor = entry_descriptor self._values = {} def __getitem__(self, key): @@ -613,3 +623,6 @@ class MessageMap(MutableMapping): def clear(self): self._values.clear() self._message_listener.Modified() + + def GetEntryClass(self): + return self._entry_descriptor._concrete_class diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index 3c8c7935..d4de2d81 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -119,6 +119,7 @@ class DescriptorPoolTest(unittest.TestCase): self.assertEqual('google.protobuf.python.internal.Factory1Message', msg1.full_name) self.assertEqual(None, msg1.containing_type) + self.assertFalse(msg1.has_options) nested_msg1 = msg1.nested_types[0] self.assertEqual('NestedFactory1Message', nested_msg1.name) @@ -202,6 +203,7 @@ class DescriptorPoolTest(unittest.TestCase): self.assertIsInstance(enum1, descriptor.EnumDescriptor) self.assertEqual(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number) self.assertEqual(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number) + self.assertFalse(enum1.has_options) nested_enum1 = self.pool.FindEnumTypeByName( 'google.protobuf.python.internal.Factory1Message.NestedFactory1Enum') @@ -234,6 +236,8 @@ class DescriptorPoolTest(unittest.TestCase): 'google.protobuf.python.internal.Factory1Message.list_value') self.assertEqual(field.name, 'list_value') self.assertEqual(field.label, field.LABEL_REPEATED) + self.assertFalse(field.has_options) + with self.assertRaises(KeyError): self.pool.FindFieldByName('Does not exist') @@ -448,6 +452,7 @@ class EnumField(object): test.assertTrue(field_desc.has_default_value) test.assertEqual(enum_desc.values_by_name[self.default_value].number, field_desc.default_value) + test.assertFalse(enum_desc.values_by_name[self.default_value].has_options) test.assertEqual(msg_desc, field_desc.containing_type) test.assertEqual(enum_desc, field_desc.enum_type) diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index 623198c8..1f148ab9 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -766,6 +766,8 @@ class MakeDescriptorTest(unittest.TestCase): 'Foo2.Sub.bar_field') self.assertEqual(result.nested_types[0].fields[0].enum_type, result.nested_types[0].enum_types[0]) + self.assertFalse(result.has_options) + self.assertFalse(result.fields[0].has_options) def testMakeDescriptorWithUnsignedIntField(self): file_descriptor_proto = descriptor_pb2.FileDescriptorProto() @@ -818,6 +820,23 @@ class MakeDescriptorTest(unittest.TestCase): self.assertEqual(result.fields[index].camelcase_name, camelcase_names[index]) + def testJsonName(self): + descriptor_proto = descriptor_pb2.DescriptorProto() + descriptor_proto.name = 'TestJsonName' + names = ['field_name', 'fieldName', 'FieldName', + '_field_name', 'FIELD_NAME', 'json_name'] + json_names = ['fieldName', 'fieldName', 'FieldName', + 'FieldName', 'FIELDNAME', '@type'] + for index in range(len(names)): + field = descriptor_proto.field.add() + field.number = index + 1 + field.name = names[index] + field.json_name = '@type' + result = descriptor.MakeDescriptor(descriptor_proto) + for index in range(len(json_names)): + self.assertEqual(result.fields[index].json_name, + json_names[index]) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py index 83ea5f50..7f13f9da 100755 --- a/python/google/protobuf/internal/generator_test.py +++ b/python/google/protobuf/internal/generator_test.py @@ -227,7 +227,8 @@ class GeneratorTest(unittest.TestCase): [unittest_import_pb2.DESCRIPTOR]) self.assertEqual(unittest_import_pb2.DESCRIPTOR.dependencies, [unittest_import_public_pb2.DESCRIPTOR]) - + self.assertEqual(unittest_import_pb2.DESCRIPTOR.public_dependencies, + [unittest_import_public_pb2.DESCRIPTOR]) def testNoGenericServices(self): self.assertTrue(hasattr(unittest_no_generic_services_pb2, "TestMessage")) self.assertTrue(hasattr(unittest_no_generic_services_pb2, "FOO")) diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py index a5ee8ace..5ed65622 100644 --- a/python/google/protobuf/internal/json_format_test.py +++ b/python/google/protobuf/internal/json_format_test.py @@ -205,6 +205,15 @@ class JsonFormatTest(JsonFormatBase): parsed_message = json_format_proto3_pb2.TestMessage() self.CheckParseBack(message, parsed_message) + def testIntegersRepresentedAsFloat(self): + message = json_format_proto3_pb2.TestMessage() + json_format.Parse('{"int32Value": -2.147483648e9}', message) + self.assertEqual(message.int32_value, -2147483648) + json_format.Parse('{"int32Value": 1e5}', message) + self.assertEqual(message.int32_value, 100000) + json_format.Parse('{"int32Value": 1.0}', message) + self.assertEqual(message.int32_value, 1) + def testMapFields(self): message = json_format_proto3_pb2.TestMap() message.bool_map[True] = 1 @@ -428,6 +437,9 @@ class JsonFormatTest(JsonFormatBase): ' "value": "hello",' ' "repeatedValue": [11.1, false, null, null]' '}')) + message.Clear() + json_format.Parse('{"value": null}', message) + self.assertEqual(message.value.WhichOneof('kind'), 'null_value') def testListValueMessage(self): message = json_format_proto3_pb2.TestListValue() @@ -600,6 +612,11 @@ class JsonFormatTest(JsonFormatBase): '}', parsed_message) self.assertEqual(message, parsed_message) + # Null and {} should have different behavior for sub message. + self.assertFalse(parsed_message.HasField('message_value')) + json_format.Parse('{"messageValue": {}}', parsed_message) + self.assertTrue(parsed_message.HasField('message_value')) + # Null is not allowed to be used as an element in repeated field. self.assertRaisesRegexp( json_format.ParseError, 'Failed to parse repeatedInt32Value field: ' @@ -621,15 +638,16 @@ class JsonFormatTest(JsonFormatBase): self.CheckError('', r'Failed to load JSON: (Expecting value)|(No JSON).') - def testParseBadEnumValue(self): - self.CheckError( - '{"enumValue": 1}', - 'Enum value must be a string literal with double quotes. ' - 'Type "proto3.EnumType" has no value named 1.') + def testParseEnumValue(self): + message = json_format_proto3_pb2.TestMessage() + text = '{"enumValue": 0}' + json_format.Parse(text, message) + text = '{"enumValue": 1}' + json_format.Parse(text, message) self.CheckError( '{"enumValue": "baz"}', - 'Enum value must be a string literal with double quotes. ' - 'Type "proto3.EnumType" has no value named baz.') + 'Failed to parse enumValue field: Invalid enum value baz ' + 'for enum type proto3.EnumType.') def testParseBadIdentifer(self): self.CheckError('{int32Value: 1}', @@ -672,12 +690,12 @@ class JsonFormatTest(JsonFormatBase): text = '{"int32Value": 0x12345}' self.assertRaises(json_format.ParseError, json_format.Parse, text, message) + self.CheckError('{"int32Value": 1.5}', + 'Failed to parse int32Value field: ' + 'Couldn\'t parse integer: 1.5.') self.CheckError('{"int32Value": 012345}', (r'Failed to load JSON: Expecting \'?,\'? delimiter: ' r'line 1.')) - self.CheckError('{"int32Value": 1.0}', - 'Failed to parse int32Value field: ' - 'Couldn\'t parse integer: 1.0.') self.CheckError('{"int32Value": " 1 "}', 'Failed to parse int32Value field: ' 'Couldn\'t parse integer: " 1 ".') @@ -687,9 +705,6 @@ class JsonFormatTest(JsonFormatBase): self.CheckError('{"int32Value": 12345678901234567890}', 'Failed to parse int32Value field: Value out of range: ' '12345678901234567890.') - self.CheckError('{"int32Value": 1e5}', - 'Failed to parse int32Value field: ' - 'Couldn\'t parse integer: 100000.0.') self.CheckError('{"uint32Value": -1}', 'Failed to parse uint32Value field: ' 'Value out of range: -1.') @@ -810,6 +825,43 @@ class JsonFormatTest(JsonFormatBase): r'"value": 1234}') json_format.Parse(text, message) + def testPreservingProtoFieldNames(self): + message = json_format_proto3_pb2.TestMessage() + message.int32_value = 12345 + self.assertEqual('{\n "int32Value": 12345\n}', + json_format.MessageToJson(message)) + self.assertEqual('{\n "int32_value": 12345\n}', + json_format.MessageToJson(message, False, True)) + + # Parsers accept both original proto field names and lowerCamelCase names. + message = json_format_proto3_pb2.TestMessage() + json_format.Parse('{"int32Value": 54321}', message) + self.assertEqual(54321, message.int32_value) + json_format.Parse('{"int32_value": 12345}', message) + self.assertEqual(12345, message.int32_value) + + def testParseDict(self): + expected = 12345 + js_dict = {'int32Value': expected} + message = json_format_proto3_pb2.TestMessage() + json_format.ParseDict(js_dict, message) + self.assertEqual(expected, message.int32_value) + + def testMessageToDict(self): + message = json_format_proto3_pb2.TestMessage() + message.int32_value = 12345 + expected = {'int32Value': 12345} + self.assertEqual(expected, + json_format.MessageToDict(message)) + + def testJsonName(self): + message = json_format_proto3_pb2.TestCustomJsonName() + message.value = 12345 + self.assertEqual('{\n "@value": 12345\n}', + json_format.MessageToJson(message)) + parsed_message = json_format_proto3_pb2.TestCustomJsonName() + self.CheckParseBack(message, parsed_message) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 1e95adf9..9986c0d9 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -67,6 +67,7 @@ from google.protobuf import text_format from google.protobuf.internal import api_implementation from google.protobuf.internal import packed_field_test_pb2 from google.protobuf.internal import test_util +from google.protobuf.internal import testing_refleaks from google.protobuf import message from google.protobuf.internal import _parameterized @@ -88,10 +89,13 @@ def IsNegInf(val): return isinf(val) and (val < 0) -@_parameterized.Parameters( - (unittest_pb2), - (unittest_proto3_arena_pb2)) -class MessageTest(unittest.TestCase): +BaseTestCase = testing_refleaks.BaseTestCase + + +@_parameterized.NamedParameters( + ('_proto2', unittest_pb2), + ('_proto3', unittest_proto3_arena_pb2)) +class MessageTest(BaseTestCase): def testBadUtf8String(self, message_module): if api_implementation.Type() != 'python': @@ -957,7 +961,7 @@ class MessageTest(unittest.TestCase): # Class to test proto2-only features (required, extensions, etc.) -class Proto2Test(unittest.TestCase): +class Proto2Test(BaseTestCase): def testFieldPresence(self): message = unittest_pb2.TestAllTypes() @@ -1113,6 +1117,7 @@ class Proto2Test(unittest.TestCase): optional_bytes=b'x', optionalgroup={'a': 400}, optional_nested_message={'bb': 500}, + optional_foreign_message={}, optional_nested_enum='BAZ', repeatedgroup=[{'a': 600}, {'a': 700}], @@ -1125,8 +1130,12 @@ class Proto2Test(unittest.TestCase): self.assertEqual(300.5, message.optional_float) self.assertEqual(b'x', message.optional_bytes) self.assertEqual(400, message.optionalgroup.a) - self.assertIsInstance(message.optional_nested_message, unittest_pb2.TestAllTypes.NestedMessage) + self.assertIsInstance(message.optional_nested_message, + unittest_pb2.TestAllTypes.NestedMessage) self.assertEqual(500, message.optional_nested_message.bb) + self.assertTrue(message.HasField('optional_foreign_message')) + self.assertEqual(message.optional_foreign_message, + unittest_pb2.ForeignMessage()) self.assertEqual(unittest_pb2.TestAllTypes.BAZ, message.optional_nested_enum) self.assertEqual(2, len(message.repeatedgroup)) @@ -1164,7 +1173,7 @@ class Proto2Test(unittest.TestCase): # Class to test proto3-only features/behavior (updated field presence & enums) -class Proto3Test(unittest.TestCase): +class Proto3Test(BaseTestCase): # Utility method for comparing equality with a map. def assertMapIterEquals(self, map_iter, dict_value): @@ -1720,7 +1729,7 @@ class Proto3Test(unittest.TestCase): -class ValidTypeNamesTest(unittest.TestCase): +class ValidTypeNamesTest(BaseTestCase): def assertImportFromName(self, msg, base_name): # Parse <type 'module.class_name'> to extra 'some.name' as a string. @@ -1741,7 +1750,7 @@ class ValidTypeNamesTest(unittest.TestCase): self.assertImportFromName(pb.repeated_int32, 'Scalar') self.assertImportFromName(pb.repeated_nested_message, 'Composite') -class PackedFieldTest(unittest.TestCase): +class PackedFieldTest(BaseTestCase): def setMessage(self, message): message.repeated_int32.append(1) @@ -1800,10 +1809,14 @@ class PackedFieldTest(unittest.TestCase): @unittest.skipIf(api_implementation.Type() != 'cpp', 'explicit tests of the C++ implementation') -class OversizeProtosTest(unittest.TestCase): - - def setUp(self): - self.file_desc = """ +class OversizeProtosTest(BaseTestCase): + + @classmethod + def setUpClass(cls): + # At the moment, reference cycles between DescriptorPool and Message classes + # are not detected and these objects are never freed. + # To avoid errors with ReferenceLeakChecker, we create the class only once. + file_desc = """ name: "f/f.msg2" package: "f" message_type { @@ -1828,10 +1841,12 @@ class OversizeProtosTest(unittest.TestCase): """ pool = descriptor_pool.DescriptorPool() desc = descriptor_pb2.FileDescriptorProto() - text_format.Parse(self.file_desc, desc) + text_format.Parse(file_desc, desc) pool.Add(desc) - self.proto_cls = message_factory.MessageFactory(pool).GetPrototype( + cls.proto_cls = message_factory.MessageFactory(pool).GetPrototype( pool.FindMessageTypeByName('f.msg2')) + + def setUp(self): self.p = self.proto_cls() self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1) self.p_serialized = self.p.SerializeToString() diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index c0d0ad45..60b4baad 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -380,13 +380,15 @@ def _GetInitializeDefaultForMap(field): if _IsMessageMapField(field): def MakeMessageMapDefault(message): return containers.MessageMap( - message._listener_for_children, value_field.message_type, key_checker) + message._listener_for_children, value_field.message_type, key_checker, + field.message_type) return MakeMessageMapDefault else: value_checker = type_checkers.GetTypeChecker(value_field) def MakePrimitiveMapDefault(message): return containers.ScalarMap( - message._listener_for_children, key_checker, value_checker) + message._listener_for_children, key_checker, value_checker, + field.message_type) return MakePrimitiveMapDefault def _DefaultValueConstructorForField(field): diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 6f3b818a..dad79c37 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -60,9 +60,13 @@ from google.protobuf.internal import more_messages_pb2 from google.protobuf.internal import message_set_extensions_pb2 from google.protobuf.internal import wire_format from google.protobuf.internal import test_util +from google.protobuf.internal import testing_refleaks from google.protobuf.internal import decoder +BaseTestCase = testing_refleaks.BaseTestCase + + class _MiniDecoder(object): """Decodes a stream of values from a string. @@ -108,7 +112,7 @@ class _MiniDecoder(object): return self._pos == len(self._bytes) -class ReflectionTest(unittest.TestCase): +class ReflectionTest(BaseTestCase): def assertListsEqual(self, values, others): self.assertEqual(len(values), len(others)) @@ -1552,6 +1556,20 @@ class ReflectionTest(unittest.TestCase): self.assertFalse(proto.HasField('optional_foreign_message')) self.assertEqual(0, proto.optional_foreign_message.c) + def testDisconnectingInOneof(self): + m = unittest_pb2.TestOneof2() # This message has two messages in a oneof. + m.foo_message.qux_int = 5 + sub_message = m.foo_message + # Accessing another message's field does not clear the first one + self.assertEqual(m.foo_lazy_message.qux_int, 0) + self.assertEqual(m.foo_message.qux_int, 5) + # But mutating another message in the oneof detaches the first one. + m.foo_lazy_message.qux_int = 6 + self.assertEqual(m.foo_message.qux_int, 0) + # The reference we got above was detached and is still valid. + self.assertEqual(sub_message.qux_int, 5) + sub_message.qux_int = 7 + def testOneOf(self): proto = unittest_pb2.TestAllTypes() proto.oneof_uint32 = 10 @@ -1810,7 +1828,7 @@ class ReflectionTest(unittest.TestCase): # into separate TestCase classes. -class TestAllTypesEqualityTest(unittest.TestCase): +class TestAllTypesEqualityTest(BaseTestCase): def setUp(self): self.first_proto = unittest_pb2.TestAllTypes() @@ -1826,7 +1844,7 @@ class TestAllTypesEqualityTest(unittest.TestCase): self.assertEqual(self.first_proto, self.second_proto) -class FullProtosEqualityTest(unittest.TestCase): +class FullProtosEqualityTest(BaseTestCase): """Equality tests using completely-full protos as a starting point.""" @@ -1912,7 +1930,7 @@ class FullProtosEqualityTest(unittest.TestCase): self.assertEqual(self.first_proto, self.second_proto) -class ExtensionEqualityTest(unittest.TestCase): +class ExtensionEqualityTest(BaseTestCase): def testExtensionEquality(self): first_proto = unittest_pb2.TestAllExtensions() @@ -1945,7 +1963,7 @@ class ExtensionEqualityTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) -class MutualRecursionEqualityTest(unittest.TestCase): +class MutualRecursionEqualityTest(BaseTestCase): def testEqualityWithMutualRecursion(self): first_proto = unittest_pb2.TestMutualRecursionA() @@ -1957,7 +1975,7 @@ class MutualRecursionEqualityTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) -class ByteSizeTest(unittest.TestCase): +class ByteSizeTest(BaseTestCase): def setUp(self): self.proto = unittest_pb2.TestAllTypes() @@ -2253,7 +2271,7 @@ class ByteSizeTest(unittest.TestCase): # * Handling of empty submessages (with and without "has" # bits set). -class SerializationTest(unittest.TestCase): +class SerializationTest(BaseTestCase): def testSerializeEmtpyMessage(self): first_proto = unittest_pb2.TestAllTypes() @@ -2814,7 +2832,7 @@ class SerializationTest(unittest.TestCase): self.assertEqual(3, proto.repeated_int32[2]) -class OptionsTest(unittest.TestCase): +class OptionsTest(BaseTestCase): def testMessageOptions(self): proto = message_set_extensions_pb2.TestMessageSet() @@ -2841,7 +2859,7 @@ class OptionsTest(unittest.TestCase): -class ClassAPITest(unittest.TestCase): +class ClassAPITest(BaseTestCase): @unittest.skipIf( api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, @@ -2924,6 +2942,9 @@ class ClassAPITest(unittest.TestCase): text_format.Merge(file_descriptor_str, file_descriptor) return file_descriptor.SerializeToString() + @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') + # This test can only run once; the second time, it raises errors about + # conflicting message descriptors. def testParsingFlatClassWithExplicitClassDeclaration(self): """Test that the generated class can parse a flat message.""" # TODO(xiaofeng): This test fails with cpp implemetnation in the call @@ -2948,6 +2969,7 @@ class ClassAPITest(unittest.TestCase): text_format.Merge(msg_str, msg) self.assertEqual(msg.flat, [0, 1, 2]) + @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') def testParsingFlatClass(self): """Test that the generated class can parse a flat message.""" file_descriptor = descriptor_pb2.FileDescriptorProto() @@ -2963,6 +2985,7 @@ class ClassAPITest(unittest.TestCase): text_format.Merge(msg_str, msg) self.assertEqual(msg.flat, [0, 1, 2]) + @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') def testParsingNestedClass(self): """Test that the generated class can parse a nested message.""" file_descriptor = descriptor_pb2.FileDescriptorProto() diff --git a/python/google/protobuf/internal/testing_refleaks.py b/python/google/protobuf/internal/testing_refleaks.py new file mode 100644 index 00000000..b2787901 --- /dev/null +++ b/python/google/protobuf/internal/testing_refleaks.py @@ -0,0 +1,124 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""A subclass of unittest.TestCase which checks for reference leaks. + +To use: +- Use testing_refleak.BaseTestCase instead of unittest.TestCase +- Configure and compile Python with --with-pydebug + +If sys.gettotalrefcount() is not available (because Python was built without +the Py_DEBUG option), then this module is a no-op and tests will run normally. +""" + +import copy_reg +import gc +import sys + +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + + +class LocalTestResult(unittest.TestResult): + """A TestResult which forwards events to a parent object, except for Skips.""" + + def __init__(self, parent_result): + unittest.TestResult.__init__(self) + self.parent_result = parent_result + + def addError(self, test, error): + self.parent_result.addError(test, error) + + def addFailure(self, test, error): + self.parent_result.addFailure(test, error) + + def addSkip(self, test, reason): + pass + + +class ReferenceLeakCheckerTestCase(unittest.TestCase): + """A TestCase which runs tests multiple times, collecting reference counts.""" + + NB_RUNS = 3 + + def run(self, result=None): + # python_message.py registers all Message classes to some pickle global + # registry, which makes the classes immortal. + # We save a copy of this registry, and reset it before we could references. + self._saved_pickle_registry = copy_reg.dispatch_table.copy() + + # Run the test twice, to warm up the instance attributes. + super(ReferenceLeakCheckerTestCase, self).run(result=result) + super(ReferenceLeakCheckerTestCase, self).run(result=result) + + oldrefcount = 0 + local_result = LocalTestResult(result) + + refcount_deltas = [] + for _ in range(self.NB_RUNS): + oldrefcount = self._getRefcounts() + super(ReferenceLeakCheckerTestCase, self).run(result=local_result) + newrefcount = self._getRefcounts() + refcount_deltas.append(newrefcount - oldrefcount) + print refcount_deltas, self + + try: + self.assertEqual(refcount_deltas, [0] * self.NB_RUNS) + except Exception: # pylint: disable=broad-except + result.addError(self, sys.exc_info()) + + def _getRefcounts(self): + copy_reg.dispatch_table.clear() + copy_reg.dispatch_table.update(self._saved_pickle_registry) + # It is sometimes necessary to gc.collect() multiple times, to ensure + # that all objects can be collected. + gc.collect() + gc.collect() + gc.collect() + return sys.gettotalrefcount() + + +if hasattr(sys, 'gettotalrefcount'): + BaseTestCase = ReferenceLeakCheckerTestCase + SkipReferenceLeakChecker = unittest.skip + +else: + # When PyDEBUG is not enabled, run the tests normally. + BaseTestCase = unittest.TestCase + + def SkipReferenceLeakChecker(reason): + del reason # Don't skip, so don't need a reason. + def Same(func): + return func + return Same + + diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 0e38e0e9..ab481ab4 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -52,6 +52,7 @@ from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation +from google.protobuf.internal import any_test_pb2 as test_extend_any from google.protobuf.internal import test_util from google.protobuf.internal import message_set_extensions_pb2 from google.protobuf import descriptor_pool @@ -684,6 +685,21 @@ class Proto2Tests(TextFormatBase): self.assertEqual(23, message.message_set.Extensions[ext1].i) self.assertEqual('foo', message.message_set.Extensions[ext2].str) + def testExtensionInsideAnyMessage(self): + message = test_extend_any.TestAny() + text = ('value {\n' + ' [type.googleapis.com/google.protobuf.internal.TestAny] {\n' + ' [google.protobuf.internal.TestAnyExtension1.extension1] {\n' + ' i: 10\n' + ' }\n' + ' }\n' + '}\n') + text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default()) + self.CompareToGoldenText( + text_format.MessageToString( + message, descriptor_pool=descriptor_pool.Default()), + text) + def testParseMessageByFieldNumber(self): message = unittest_pb2.TestAllTypes() text = ('34: 1\n' 'repeated_uint64: 2\n') @@ -1184,7 +1200,8 @@ class TokenizerTest(unittest.TestCase): 'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n' 'ID9: 22 ID10: -111111111111111111 ID11: -22\n' 'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f ' - 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ') + 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ' + 'False_bool: False True_bool: True') tokenizer = text_format.Tokenizer(text.splitlines()) methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), ':', (tokenizer.ConsumeString, 'string1'), @@ -1228,7 +1245,11 @@ class TokenizerTest(unittest.TestCase): (tokenizer.ConsumeIdentifier, 'true_bool1'), ':', (tokenizer.ConsumeBool, True), (tokenizer.ConsumeIdentifier, 'false_BOOL1'), ':', - (tokenizer.ConsumeBool, False)] + (tokenizer.ConsumeBool, False), + (tokenizer.ConsumeIdentifier, 'False_bool'), ':', + (tokenizer.ConsumeBool, False), + (tokenizer.ConsumeIdentifier, 'True_bool'), ':', + (tokenizer.ConsumeBool, True)] i = 0 while not tokenizer.AtEnd(): diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index 84073f1c..d614eaa8 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -47,16 +47,20 @@ from google.protobuf.internal import encoder from google.protobuf.internal import message_set_extensions_pb2 from google.protobuf.internal import missing_enum_values_pb2 from google.protobuf.internal import test_util +from google.protobuf.internal import testing_refleaks from google.protobuf.internal import type_checkers +BaseTestCase = testing_refleaks.BaseTestCase + + def SkipIfCppImplementation(func): return unittest.skipIf( api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, 'C++ implementation does not expose unknown fields to Python')(func) -class UnknownFieldsTest(unittest.TestCase): +class UnknownFieldsTest(BaseTestCase): def setUp(self): self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR @@ -140,7 +144,7 @@ class UnknownFieldsTest(unittest.TestCase): b'', message.repeated_nested_message[0].SerializeToString()) -class UnknownFieldsAccessorsTest(unittest.TestCase): +class UnknownFieldsAccessorsTest(BaseTestCase): def setUp(self): self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR @@ -149,21 +153,18 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): self.all_fields_data = self.all_fields.SerializeToString() self.empty_message = unittest_pb2.TestEmptyMessage() self.empty_message.ParseFromString(self.all_fields_data) - if api_implementation.Type() != 'cpp': - # _unknown_fields is an implementation detail. - self.unknown_fields = self.empty_message._unknown_fields - # All the tests that use GetField() check an implementation detail of the - # Python implementation, which stores unknown fields as serialized strings. - # These tests are skipped by the C++ implementation: it's enough to check that - # the message is correctly serialized. + # GetUnknownField() checks a detail of the Python implementation, which stores + # unknown fields as serialized strings. It cannot be used by the C++ + # implementation: it's enough to check that the message is correctly + # serialized. - def GetField(self, name): + def GetUnknownField(self, name): field_descriptor = self.descriptor.fields_by_name[name] wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] field_tag = encoder.TagBytes(field_descriptor.number, wire_type) result_dict = {} - for tag_bytes, value in self.unknown_fields: + for tag_bytes, value in self.empty_message._unknown_fields: if tag_bytes == field_tag: decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0] decoder(value, 0, len(value), self.all_fields, result_dict) @@ -171,37 +172,37 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): @SkipIfCppImplementation def testEnum(self): - value = self.GetField('optional_nested_enum') + value = self.GetUnknownField('optional_nested_enum') self.assertEqual(self.all_fields.optional_nested_enum, value) @SkipIfCppImplementation def testRepeatedEnum(self): - value = self.GetField('repeated_nested_enum') + value = self.GetUnknownField('repeated_nested_enum') self.assertEqual(self.all_fields.repeated_nested_enum, value) @SkipIfCppImplementation def testVarint(self): - value = self.GetField('optional_int32') + value = self.GetUnknownField('optional_int32') self.assertEqual(self.all_fields.optional_int32, value) @SkipIfCppImplementation def testFixed32(self): - value = self.GetField('optional_fixed32') + value = self.GetUnknownField('optional_fixed32') self.assertEqual(self.all_fields.optional_fixed32, value) @SkipIfCppImplementation def testFixed64(self): - value = self.GetField('optional_fixed64') + value = self.GetUnknownField('optional_fixed64') self.assertEqual(self.all_fields.optional_fixed64, value) @SkipIfCppImplementation def testLengthDelimited(self): - value = self.GetField('optional_string') + value = self.GetUnknownField('optional_string') self.assertEqual(self.all_fields.optional_string, value) @SkipIfCppImplementation def testGroup(self): - value = self.GetField('optionalgroup') + value = self.GetUnknownField('optionalgroup') self.assertEqual(self.all_fields.optionalgroup, value) def testCopyFrom(self): @@ -241,43 +242,41 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): self.assertEqual(message.SerializeToString(), self.all_fields_data) -class UnknownEnumValuesTest(unittest.TestCase): +class UnknownEnumValuesTest(BaseTestCase): def setUp(self): self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR self.message = missing_enum_values_pb2.TestEnumValues() + # TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum. self.message.optional_nested_enum = ( - missing_enum_values_pb2.TestEnumValues.ZERO) + missing_enum_values_pb2.TestEnumValues.ZERO) self.message.repeated_nested_enum.extend([ - missing_enum_values_pb2.TestEnumValues.ZERO, - missing_enum_values_pb2.TestEnumValues.ONE, - ]) + missing_enum_values_pb2.TestEnumValues.ZERO, + missing_enum_values_pb2.TestEnumValues.ONE, + ]) self.message.packed_nested_enum.extend([ - missing_enum_values_pb2.TestEnumValues.ZERO, - missing_enum_values_pb2.TestEnumValues.ONE, - ]) + missing_enum_values_pb2.TestEnumValues.ZERO, + missing_enum_values_pb2.TestEnumValues.ONE, + ]) self.message_data = self.message.SerializeToString() self.missing_message = missing_enum_values_pb2.TestMissingEnumValues() self.missing_message.ParseFromString(self.message_data) - if api_implementation.Type() != 'cpp': - # _unknown_fields is an implementation detail. - self.unknown_fields = self.missing_message._unknown_fields - # All the tests that use GetField() check an implementation detail of the - # Python implementation, which stores unknown fields as serialized strings. - # These tests are skipped by the C++ implementation: it's enough to check that - # the message is correctly serialized. + # GetUnknownField() checks a detail of the Python implementation, which stores + # unknown fields as serialized strings. It cannot be used by the C++ + # implementation: it's enough to check that the message is correctly + # serialized. - def GetField(self, name): + def GetUnknownField(self, name): field_descriptor = self.descriptor.fields_by_name[name] wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] field_tag = encoder.TagBytes(field_descriptor.number, wire_type) result_dict = {} - for tag_bytes, value in self.unknown_fields: + for tag_bytes, value in self.missing_message._unknown_fields: if tag_bytes == field_tag: decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[ - tag_bytes][0] + tag_bytes][0] decoder(value, 0, len(value), self.message, result_dict) return result_dict[field_descriptor] @@ -294,21 +293,39 @@ class UnknownEnumValuesTest(unittest.TestCase): # default value. self.assertEqual(missing.optional_nested_enum, 0) - @SkipIfCppImplementation def testUnknownEnumValue(self): + if api_implementation.Type() == 'cpp': + # The CPP implementation of protos (wrongly) allows unknown enum values + # for proto2. + self.assertTrue(self.missing_message.HasField('optional_nested_enum')) + self.assertEqual(self.message.optional_nested_enum, + self.missing_message.optional_nested_enum) + else: + # On the other hand, the Python implementation considers unknown values + # as unknown fields. This is the correct behavior. + self.assertFalse(self.missing_message.HasField('optional_nested_enum')) + value = self.GetUnknownField('optional_nested_enum') + self.assertEqual(self.message.optional_nested_enum, value) + self.missing_message.ClearField('optional_nested_enum') self.assertFalse(self.missing_message.HasField('optional_nested_enum')) - value = self.GetField('optional_nested_enum') - self.assertEqual(self.message.optional_nested_enum, value) - @SkipIfCppImplementation def testUnknownRepeatedEnumValue(self): - value = self.GetField('repeated_nested_enum') - self.assertEqual(self.message.repeated_nested_enum, value) + if api_implementation.Type() == 'cpp': + # For repeated enums, both implementations agree. + self.assertEqual([], self.missing_message.repeated_nested_enum) + else: + self.assertEqual([], self.missing_message.repeated_nested_enum) + value = self.GetUnknownField('repeated_nested_enum') + self.assertEqual(self.message.repeated_nested_enum, value) - @SkipIfCppImplementation def testUnknownPackedEnumValue(self): - value = self.GetField('packed_nested_enum') - self.assertEqual(self.message.packed_nested_enum, value) + if api_implementation.Type() == 'cpp': + # For repeated enums, both implementations agree. + self.assertEqual([], self.missing_message.packed_nested_enum) + else: + self.assertEqual([], self.missing_message.packed_nested_enum) + value = self.GetUnknownField('packed_nested_enum') + self.assertEqual(self.message.packed_nested_enum, value) def testRoundTrip(self): new_message = missing_enum_values_pb2.TestEnumValues() diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py index 7c5dffd0..d631abee 100644 --- a/python/google/protobuf/internal/well_known_types.py +++ b/python/google/protobuf/internal/well_known_types.py @@ -53,6 +53,7 @@ _NANOS_PER_MICROSECOND = 1000 _MILLIS_PER_SECOND = 1000 _MICROS_PER_SECOND = 1000000 _SECONDS_PER_DAY = 24 * 3600 +_DURATION_SECONDS_MAX = 315576000000 class Error(Exception): @@ -247,6 +248,7 @@ class Duration(object): represent the exact Duration value. For example: "1s", "1.010s", "1.000000100s", "-3.100s" """ + _CheckDurationValid(self.seconds, self.nanos) if self.seconds < 0 or self.nanos < 0: result = '-' seconds = - self.seconds + int((0 - self.nanos) // 1e9) @@ -286,14 +288,17 @@ class Duration(object): try: pos = value.find('.') if pos == -1: - self.seconds = int(value[:-1]) - self.nanos = 0 + seconds = int(value[:-1]) + nanos = 0 else: - self.seconds = int(value[:pos]) + seconds = int(value[:pos]) if value[0] == '-': - self.nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9)) + nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9)) else: - self.nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9)) + nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9)) + _CheckDurationValid(seconds, nanos) + self.seconds = seconds + self.nanos = nanos except ValueError: raise ParseError( 'Couldn\'t parse duration: {0}.'.format(value)) @@ -359,6 +364,17 @@ class Duration(object): self.nanos = nanos +def _CheckDurationValid(seconds, nanos): + if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX: + raise Error( + 'Duration is not valid: Seconds {0} must be in range ' + '[-315576000000, 315576000000].'.format(seconds)) + if nanos <= -_NANOS_PER_SECOND or nanos >= _NANOS_PER_SECOND: + raise Error( + 'Duration is not valid: Nanos {0} must be in range ' + '[-999999999, 999999999].'.format(nanos)) + + def _RoundTowardZero(value, divider): """Truncates the remainder part after division.""" # For some languanges, the sign of the remainder is implementation @@ -379,13 +395,16 @@ class FieldMask(object): def ToJsonString(self): """Converts FieldMask to string according to proto3 JSON spec.""" - return ','.join(self.paths) + camelcase_paths = [] + for path in self.paths: + camelcase_paths.append(_SnakeCaseToCamelCase(path)) + return ','.join(camelcase_paths) def FromJsonString(self, value): """Converts string to FieldMask according to proto3 JSON spec.""" self.Clear() for path in value.split(','): - self.paths.append(path) + self.paths.append(_CamelCaseToSnakeCase(path)) def IsValidForDescriptor(self, message_descriptor): """Checks whether the FieldMask is valid for Message Descriptor.""" @@ -472,6 +491,48 @@ def _CheckFieldMaskMessage(message): message_descriptor.full_name)) +def _SnakeCaseToCamelCase(path_name): + """Converts a path name from snake_case to camelCase.""" + result = [] + after_underscore = False + for c in path_name: + if c.isupper(): + raise Error('Fail to print FieldMask to Json string: Path name ' + '{0} must not contain uppercase letters.'.format(path_name)) + if after_underscore: + if c.islower(): + result.append(c.upper()) + after_underscore = False + else: + raise Error('Fail to print FieldMask to Json string: The ' + 'character after a "_" must be a lowercase letter ' + 'in path name {0}.'.format(path_name)) + elif c == '_': + after_underscore = True + else: + result += c + + if after_underscore: + raise Error('Fail to print FieldMask to Json string: Trailing "_" ' + 'in path name {0}.'.format(path_name)) + return ''.join(result) + + +def _CamelCaseToSnakeCase(path_name): + """Converts a field name from camelCase to snake_case.""" + result = [] + for c in path_name: + if c == '_': + raise ParseError('Fail to parse FieldMask: Path name ' + '{0} must not contain "_"s.'.format(path_name)) + if c.isupper(): + result += '_' + result += c.lower() + else: + result += c + return ''.join(result) + + class _FieldMaskTree(object): """Represents a FieldMask in a tree structure. diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py index 2f32ac99..077f630f 100644 --- a/python/google/protobuf/internal/well_known_types_test.py +++ b/python/google/protobuf/internal/well_known_types_test.py @@ -303,6 +303,25 @@ class TimeUtilTest(TimeUtilTestBase): well_known_types.ParseError, 'Couldn\'t parse duration: 1...2s.', message.FromJsonString, '1...2s') + text = '-315576000001.000000000s' + self.assertRaisesRegexp( + well_known_types.Error, + r'Duration is not valid\: Seconds -315576000001 must be in range' + r' \[-315576000000\, 315576000000\].', + message.FromJsonString, text) + text = '315576000001.000000000s' + self.assertRaisesRegexp( + well_known_types.Error, + r'Duration is not valid\: Seconds 315576000001 must be in range' + r' \[-315576000000\, 315576000000\].', + message.FromJsonString, text) + message.seconds = -315576000001 + message.nanos = 0 + self.assertRaisesRegexp( + well_known_types.Error, + r'Duration is not valid\: Seconds -315576000001 must be in range' + r' \[-315576000000\, 315576000000\].', + message.ToJsonString) class FieldMaskTest(unittest.TestCase): @@ -322,6 +341,20 @@ class FieldMaskTest(unittest.TestCase): mask.FromJsonString('foo,bar') self.assertEqual(['foo', 'bar'], mask.paths) + # Test camel case + mask.Clear() + mask.paths.append('foo_bar') + self.assertEqual('fooBar', mask.ToJsonString()) + mask.paths.append('bar_quz') + self.assertEqual('fooBar,barQuz', mask.ToJsonString()) + + mask.FromJsonString('') + self.assertEqual('', mask.ToJsonString()) + mask.FromJsonString('fooBar') + self.assertEqual(['foo_bar'], mask.paths) + mask.FromJsonString('fooBar,barQuz') + self.assertEqual(['foo_bar', 'bar_quz'], mask.paths) + def testDescriptorToFieldMask(self): mask = field_mask_pb2.FieldMask() msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR @@ -502,17 +535,68 @@ class FieldMaskTest(unittest.TestCase): nested_src.payload.repeated_int32.append(1234) nested_dst.payload.repeated_int32.append(5678) # Repeated fields will be appended by default. - mask.FromJsonString('payload.repeated_int32') + mask.FromJsonString('payload.repeatedInt32') mask.MergeMessage(nested_src, nested_dst) self.assertEqual(2, len(nested_dst.payload.repeated_int32)) self.assertEqual(5678, nested_dst.payload.repeated_int32[0]) self.assertEqual(1234, nested_dst.payload.repeated_int32[1]) # Change the behavior to replace repeated fields. - mask.FromJsonString('payload.repeated_int32') + mask.FromJsonString('payload.repeatedInt32') mask.MergeMessage(nested_src, nested_dst, False, True) self.assertEqual(1, len(nested_dst.payload.repeated_int32)) self.assertEqual(1234, nested_dst.payload.repeated_int32[0]) + def testSnakeCaseToCamelCase(self): + self.assertEqual('fooBar', + well_known_types._SnakeCaseToCamelCase('foo_bar')) + self.assertEqual('FooBar', + well_known_types._SnakeCaseToCamelCase('_foo_bar')) + self.assertEqual('foo3Bar', + well_known_types._SnakeCaseToCamelCase('foo3_bar')) + + # No uppercase letter is allowed. + self.assertRaisesRegexp( + well_known_types.Error, + 'Fail to print FieldMask to Json string: Path name Foo must ' + 'not contain uppercase letters.', + well_known_types._SnakeCaseToCamelCase, + 'Foo') + # Any character after a "_" must be a lowercase letter. + # 1. "_" cannot be followed by another "_". + # 2. "_" cannot be followed by a digit. + # 3. "_" cannot appear as the last character. + self.assertRaisesRegexp( + well_known_types.Error, + 'Fail to print FieldMask to Json string: The character after a ' + '"_" must be a lowercase letter in path name foo__bar.', + well_known_types._SnakeCaseToCamelCase, + 'foo__bar') + self.assertRaisesRegexp( + well_known_types.Error, + 'Fail to print FieldMask to Json string: The character after a ' + '"_" must be a lowercase letter in path name foo_3bar.', + well_known_types._SnakeCaseToCamelCase, + 'foo_3bar') + self.assertRaisesRegexp( + well_known_types.Error, + 'Fail to print FieldMask to Json string: Trailing "_" in path ' + 'name foo_bar_.', + well_known_types._SnakeCaseToCamelCase, + 'foo_bar_') + + def testCamelCaseToSnakeCase(self): + self.assertEqual('foo_bar', + well_known_types._CamelCaseToSnakeCase('fooBar')) + self.assertEqual('_foo_bar', + well_known_types._CamelCaseToSnakeCase('FooBar')) + self.assertEqual('foo3_bar', + well_known_types._CamelCaseToSnakeCase('foo3Bar')) + self.assertRaisesRegexp( + well_known_types.ParseError, + 'Fail to parse FieldMask: Path name foo_bar must not contain "_"s.', + well_known_types._CamelCaseToSnakeCase, + 'foo_bar') + class StructTest(unittest.TestCase): @@ -529,52 +613,52 @@ class StructTest(unittest.TestCase): struct_list.add_struct()['subkey2'] = 9 self.assertTrue(isinstance(struct, well_known_types.Struct)) - self.assertEquals(5, struct['key1']) - self.assertEquals('abc', struct['key2']) + self.assertEqual(5, struct['key1']) + self.assertEqual('abc', struct['key2']) self.assertIs(True, struct['key3']) - self.assertEquals(11, struct['key4']['subkey']) + self.assertEqual(11, struct['key4']['subkey']) inner_struct = struct_class() inner_struct['subkey2'] = 9 - self.assertEquals([6, 'seven', True, False, None, inner_struct], - list(struct['key5'].items())) + self.assertEqual([6, 'seven', True, False, None, inner_struct], + list(struct['key5'].items())) serialized = struct.SerializeToString() struct2 = struct_pb2.Struct() struct2.ParseFromString(serialized) - self.assertEquals(struct, struct2) + self.assertEqual(struct, struct2) self.assertTrue(isinstance(struct2, well_known_types.Struct)) - self.assertEquals(5, struct2['key1']) - self.assertEquals('abc', struct2['key2']) + self.assertEqual(5, struct2['key1']) + self.assertEqual('abc', struct2['key2']) self.assertIs(True, struct2['key3']) - self.assertEquals(11, struct2['key4']['subkey']) - self.assertEquals([6, 'seven', True, False, None, inner_struct], - list(struct2['key5'].items())) + self.assertEqual(11, struct2['key4']['subkey']) + self.assertEqual([6, 'seven', True, False, None, inner_struct], + list(struct2['key5'].items())) struct_list = struct2['key5'] - self.assertEquals(6, struct_list[0]) - self.assertEquals('seven', struct_list[1]) - self.assertEquals(True, struct_list[2]) - self.assertEquals(False, struct_list[3]) - self.assertEquals(None, struct_list[4]) - self.assertEquals(inner_struct, struct_list[5]) + self.assertEqual(6, struct_list[0]) + self.assertEqual('seven', struct_list[1]) + self.assertEqual(True, struct_list[2]) + self.assertEqual(False, struct_list[3]) + self.assertEqual(None, struct_list[4]) + self.assertEqual(inner_struct, struct_list[5]) struct_list[1] = 7 - self.assertEquals(7, struct_list[1]) + self.assertEqual(7, struct_list[1]) struct_list.add_list().extend([1, 'two', True, False, None]) - self.assertEquals([1, 'two', True, False, None], - list(struct_list[6].items())) + self.assertEqual([1, 'two', True, False, None], + list(struct_list[6].items())) text_serialized = str(struct) struct3 = struct_pb2.Struct() text_format.Merge(text_serialized, struct3) - self.assertEquals(struct, struct3) + self.assertEqual(struct, struct3) struct.get_or_create_struct('key3')['replace'] = 12 - self.assertEquals(12, struct['key3']['replace']) + self.assertEqual(12, struct['key3']['replace']) class AnyTest(unittest.TestCase): diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py index edc0cb50..c42371d0 100644 --- a/python/google/protobuf/json_format.py +++ b/python/google/protobuf/json_format.py @@ -86,7 +86,9 @@ class ParseError(Error): """Thrown in case of parsing error.""" -def MessageToJson(message, including_default_value_fields=False): +def MessageToJson(message, + including_default_value_fields=False, + preserving_proto_field_name=False): """Converts protobuf message to JSON format. Args: @@ -95,14 +97,42 @@ def MessageToJson(message, including_default_value_fields=False): repeated fields, and map fields will always be serialized. If False, only serialize non-empty fields. Singular message fields and oneof fields are not affected by this option. + preserving_proto_field_name: If True, use the original proto field + names as defined in the .proto file. If False, convert the field + names to lowerCamelCase. Returns: A string containing the JSON formatted protocol buffer message. """ - printer = _Printer(including_default_value_fields) + printer = _Printer(including_default_value_fields, + preserving_proto_field_name) return printer.ToJsonString(message) +def MessageToDict(message, + including_default_value_fields=False, + preserving_proto_field_name=False): + """Converts protobuf message to a JSON dictionary. + + Args: + message: The protocol buffers message instance to serialize. + including_default_value_fields: If True, singular primitive fields, + repeated fields, and map fields will always be serialized. If + False, only serialize non-empty fields. Singular message fields + and oneof fields are not affected by this option. + preserving_proto_field_name: If True, use the original proto field + names as defined in the .proto file. If False, convert the field + names to lowerCamelCase. + + Returns: + A dict representation of the JSON formatted protocol buffer message. + """ + printer = _Printer(including_default_value_fields, + preserving_proto_field_name) + # pylint: disable=protected-access + return printer._MessageToJsonObject(message) + + def _IsMapEntry(field): return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and field.message_type.has_options and @@ -113,8 +143,10 @@ class _Printer(object): """JSON format printer for protocol message.""" def __init__(self, - including_default_value_fields=False): + including_default_value_fields=False, + preserving_proto_field_name=False): self.including_default_value_fields = including_default_value_fields + self.preserving_proto_field_name = preserving_proto_field_name def ToJsonString(self, message): js = self._MessageToJsonObject(message) @@ -137,7 +169,10 @@ class _Printer(object): try: for field, value in fields: - name = field.camelcase_name + if self.preserving_proto_field_name: + name = field.name + else: + name = field.json_name if _IsMapEntry(field): # Convert a map field. v_field = field.message_type.fields_by_name['value'] @@ -169,7 +204,10 @@ class _Printer(object): field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE) or field.containing_oneof): continue - name = field.camelcase_name + if self.preserving_proto_field_name: + name = field.name + else: + name = field.json_name if name in js: # Skip the field which has been serailized already. continue @@ -328,8 +366,22 @@ def Parse(text, message, ignore_unknown_fields=False): js = json.loads(text, object_pairs_hook=_DuplicateChecker) except ValueError as e: raise ParseError('Failed to load JSON: {0}.'.format(str(e))) + return ParseDict(js, message, ignore_unknown_fields) + + +def ParseDict(js_dict, message, ignore_unknown_fields=False): + """Parses a JSON dictionary representation into a message. + + Args: + js_dict: Dict representation of a JSON message. + message: A protocol buffer message to merge into. + ignore_unknown_fields: If True, do not raise errors for unknown fields. + + Returns: + The same message passed as argument. + """ parser = _Parser(ignore_unknown_fields) - parser.ConvertMessage(js, message) + parser.ConvertMessage(js_dict, message) return message @@ -374,9 +426,13 @@ class _Parser(object): """ names = [] message_descriptor = message.DESCRIPTOR + fields_by_json_name = dict((f.json_name, f) + for f in message_descriptor.fields) for name in js: try: - field = message_descriptor.fields_by_camelcase_name.get(name, None) + field = fields_by_json_name.get(name, None) + if not field: + field = message_descriptor.fields_by_name.get(name, None) if not field: if self.ignore_unknown_fields: continue @@ -399,7 +455,12 @@ class _Parser(object): value = js[name] if value is None: - message.ClearField(field.name) + if (field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE + and field.message_type.full_name == 'google.protobuf.Value'): + sub_message = getattr(message, field.name) + sub_message.null_value = 0 + else: + message.ClearField(field.name) continue # Parse field value. @@ -431,6 +492,7 @@ class _Parser(object): _ConvertScalarFieldValue(item, field)) elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: sub_message = getattr(message, field.name) + sub_message.SetInParent() self.ConvertMessage(value, sub_message) else: setattr(message, field.name, _ConvertScalarFieldValue(value, field)) @@ -574,10 +636,15 @@ def _ConvertScalarFieldValue(value, field, require_str=False): # Convert an enum value. enum_value = field.enum_type.values_by_name.get(value, None) if enum_value is None: - raise ParseError( - 'Enum value must be a string literal with double quotes. ' - 'Type "{0}" has no value named {1}.'.format( - field.enum_type.full_name, value)) + try: + number = int(value) + enum_value = field.enum_type.values_by_number.get(number, None) + except ValueError: + raise ParseError('Invalid enum value {0} for enum type {1}.'.format( + value, field.enum_type.full_name)) + if enum_value is None: + raise ParseError('Invalid enum value {0} for enum type {1}.'.format( + value, field.enum_type.full_name)) return enum_value.number @@ -593,7 +660,7 @@ def _ConvertInteger(value): Raises: ParseError: If an integer couldn't be consumed. """ - if isinstance(value, float): + if isinstance(value, float) and not value.is_integer(): raise ParseError('Couldn\'t parse integer: {0}.'.format(value)) if isinstance(value, six.text_type) and value.find(' ') != -1: diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py index 606f735f..aab250e4 100755 --- a/python/google/protobuf/message.py +++ b/python/google/protobuf/message.py @@ -225,10 +225,11 @@ class Message(object): # """ def ListFields(self): """Returns a list of (FieldDescriptor, value) tuples for all - fields in the message which are not empty. A singular field is non-empty - if HasField() would return true, and a repeated field is non-empty if - it contains at least one element. The fields are ordered by field - number""" + fields in the message which are not empty. A message field is + non-empty if HasField() would return true. A singular primitive field + is non-empty if HasField() would return true in proto2 or it is non zero + in proto3. A repeated field is non-empty if it contains at least one + element. The fields are ordered by field number""" raise NotImplementedError def HasField(self, field_name): diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py index 1b059d13..8ab1c513 100644 --- a/python/google/protobuf/message_factory.py +++ b/python/google/protobuf/message_factory.py @@ -103,13 +103,8 @@ class MessageFactory(object): result = {} for file_name in files: file_desc = self.pool.FindFileByName(file_name) - for name, msg in file_desc.message_types_by_name.items(): - if file_desc.package: - full_name = '.'.join([file_desc.package, name]) - else: - full_name = msg.name - result[full_name] = self.GetPrototype( - self.pool.FindMessageTypeByName(full_name)) + for desc in file_desc.message_types_by_name.values(): + result[desc.full_name] = self.GetPrototype(desc) # While the extension FieldDescriptors are created by the descriptor pool, # the python classes created in the factory need them to be registered @@ -120,7 +115,7 @@ class MessageFactory(object): # ignore the registration if the original was the same, or raise # an error if they were different. - for name, extension in file_desc.extensions_by_name.items(): + for extension in file_desc.extensions_by_name.values(): if extension.containing_type.full_name not in self._classes: self.GetPrototype(extension.containing_type) extended_class = self._classes[extension.containing_type.full_name] diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index e6ef5ef5..924ae0b9 100644 --- a/python/google/protobuf/pyext/descriptor.cc +++ b/python/google/protobuf/pyext/descriptor.cc @@ -41,6 +41,7 @@ #include <google/protobuf/pyext/descriptor_containers.h> #include <google/protobuf/pyext/descriptor_pool.h> #include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/message_factory.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #if PY_MAJOR_VERSION >= 3 @@ -204,8 +205,9 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { // read-only instance. const Message& options(descriptor->options()); const Descriptor *message_type = options.GetDescriptor(); - CMessageClass* message_class( - cdescriptor_pool::GetMessageClass(pool, message_type)); + PyMessageFactory* message_factory = pool->py_message_factory; + CMessageClass* message_class = message_factory::GetMessageClass( + message_factory, message_type); if (message_class == NULL) { // The Options message was not found in the current DescriptorPool. // This means that the pool cannot contain any extensions to the Options @@ -213,7 +215,9 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { // the chances of successfully parsing the options. PyErr_Clear(); pool = GetDefaultDescriptorPool(); - message_class = cdescriptor_pool::GetMessageClass(pool, message_type); + message_factory = pool->py_message_factory; + message_class = message_factory::GetMessageClass( + message_factory, message_type); } if (message_class == NULL) { PyErr_Format(PyExc_TypeError, "Could not retrieve class for Options: %s", @@ -243,7 +247,7 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { options.SerializeToString(&serialized); io::CodedInputStream input( reinterpret_cast<const uint8*>(serialized.c_str()), serialized.size()); - input.SetExtensionRegistry(pool->pool, pool->message_factory); + input.SetExtensionRegistry(pool->pool, message_factory->message_factory); bool success = cmsg->message->MergePartialFromCodedStream(&input); if (!success) { PyErr_Format(PyExc_ValueError, "Error parsing Options message"); @@ -439,8 +443,9 @@ static PyObject* GetConcreteClass(PyBaseDescriptor* self, void *closure) { // which contains this descriptor. // This might not be the one you expect! For example the returned object does // not know about extensions defined in a custom pool. - CMessageClass* concrete_class(cdescriptor_pool::GetMessageClass( - GetDescriptorPool_FromPool(_GetDescriptor(self)->file()->pool()), + CMessageClass* concrete_class(message_factory::GetMessageClass( + GetDescriptorPool_FromPool( + _GetDescriptor(self)->file()->pool())->py_message_factory, _GetDescriptor(self))); Py_XINCREF(concrete_class); return concrete_class->AsPyObject(); @@ -699,6 +704,10 @@ static PyObject* GetCamelcaseName(PyBaseDescriptor* self, void *closure) { return PyString_FromCppString(_GetDescriptor(self)->camelcase_name()); } +static PyObject* GetJsonName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->json_name()); +} + static PyObject* GetType(PyBaseDescriptor *self, void *closure) { return PyInt_FromLong(_GetDescriptor(self)->type()); } @@ -888,6 +897,7 @@ static PyGetSetDef Getters[] = { { "full_name", (getter)GetFullName, NULL, "Full name"}, { "name", (getter)GetName, NULL, "Unqualified name"}, { "camelcase_name", (getter)GetCamelcaseName, NULL, "Camelcase name"}, + { "json_name", (getter)GetJsonName, NULL, "Json name"}, { "type", (getter)GetType, NULL, "C++ Type"}, { "cpp_type", (getter)GetCppType, NULL, "C++ Type"}, { "label", (getter)GetLabel, NULL, "Label"}, diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc index cfd98690..a42e5431 100644 --- a/python/google/protobuf/pyext/descriptor_pool.cc +++ b/python/google/protobuf/pyext/descriptor_pool.cc @@ -33,11 +33,11 @@ #include <Python.h> #include <google/protobuf/descriptor.pb.h> -#include <google/protobuf/dynamic_message.h> #include <google/protobuf/pyext/descriptor.h> #include <google/protobuf/pyext/descriptor_database.h> #include <google/protobuf/pyext/descriptor_pool.h> #include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/message_factory.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #if PY_MAJOR_VERSION >= 3 @@ -73,18 +73,16 @@ static PyDescriptorPool* _CreateDescriptorPool() { cpool->underlay = NULL; cpool->database = NULL; - DynamicMessageFactory* message_factory = new DynamicMessageFactory(); - // This option might be the default some day. - message_factory->SetDelegateToGeneratedFactory(true); - cpool->message_factory = message_factory; - - // TODO(amauryfa): Rewrite the SymbolDatabase in C so that it uses the same - // storage. - cpool->classes_by_descriptor = - new PyDescriptorPool::ClassesByMessageMap(); cpool->descriptor_options = new hash_map<const void*, PyObject *>(); + cpool->py_message_factory = message_factory::NewMessageFactory( + &PyMessageFactory_Type, cpool); + if (cpool->py_message_factory == NULL) { + Py_DECREF(cpool); + return NULL; + } + return cpool; } @@ -151,20 +149,14 @@ static PyObject* New(PyTypeObject* type, } static void Dealloc(PyDescriptorPool* self) { - typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; descriptor_pool_map.erase(self->pool); - for (iterator it = self->classes_by_descriptor->begin(); - it != self->classes_by_descriptor->end(); ++it) { - Py_DECREF(it->second); - } - delete self->classes_by_descriptor; + Py_CLEAR(self->py_message_factory); for (hash_map<const void*, PyObject*>::iterator it = self->descriptor_options->begin(); it != self->descriptor_options->end(); ++it) { Py_DECREF(it->second); } delete self->descriptor_options; - delete self->message_factory; delete self->database; delete self->pool; Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); @@ -188,35 +180,8 @@ PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* arg) { return PyMessageDescriptor_FromDescriptor(message_descriptor); } -// Add a message class to our database. -int RegisterMessageClass(PyDescriptorPool* self, - const Descriptor* message_descriptor, - CMessageClass* message_class) { - Py_INCREF(message_class); - typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; - std::pair<iterator, bool> ret = self->classes_by_descriptor->insert( - std::make_pair(message_descriptor, message_class)); - if (!ret.second) { - // Update case: DECREF the previous value. - Py_DECREF(ret.first->second); - ret.first->second = message_class; - } - return 0; -} -// Retrieve the message class added to our database. -CMessageClass* GetMessageClass(PyDescriptorPool* self, - const Descriptor* message_descriptor) { - typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; - iterator ret = self->classes_by_descriptor->find(message_descriptor); - if (ret == self->classes_by_descriptor->end()) { - PyErr_Format(PyExc_TypeError, "No message class registered for '%s'", - message_descriptor->full_name().c_str()); - return NULL; - } else { - return ret->second; - } -} + PyObject* FindFileByName(PyDescriptorPool* self, PyObject* arg) { Py_ssize_t name_size; @@ -228,11 +193,9 @@ PyObject* FindFileByName(PyDescriptorPool* self, PyObject* arg) { const FileDescriptor* file_descriptor = self->pool->FindFileByName(string(name, name_size)); if (file_descriptor == NULL) { - PyErr_Format(PyExc_KeyError, "Couldn't find file %.200s", - name); + PyErr_Format(PyExc_KeyError, "Couldn't find file %.200s", name); return NULL; } - return PyFileDescriptor_FromDescriptor(file_descriptor); } diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h index 2a42c112..8de6c60b 100644 --- a/python/google/protobuf/pyext/descriptor_pool.h +++ b/python/google/protobuf/pyext/descriptor_pool.h @@ -38,10 +38,10 @@ namespace google { namespace protobuf { -class MessageFactory; - namespace python { +class PyMessageFactory; + // The (meta) type of all Messages classes. struct CMessageClass; @@ -69,20 +69,10 @@ typedef struct PyDescriptorPool { // This pointer is owned. const DescriptorDatabase* database; - // DynamicMessageFactory used to create C++ instances of messages. - // This object cache the descriptors that were used, so the DescriptorPool - // needs to get rid of it before it can delete itself. - // - // Note: A C++ MessageFactory is different from the Python MessageFactory. - // The C++ one creates messages, when the Python one creates classes. - MessageFactory* message_factory; - - // Make our own mapping to retrieve Python classes from C++ descriptors. - // - // Descriptor pointers stored here are owned by the DescriptorPool above. - // Python references to classes are owned by this PyDescriptorPool. - typedef hash_map<const Descriptor*, CMessageClass*> ClassesByMessageMap; - ClassesByMessageMap* classes_by_descriptor; + // The preferred MessageFactory to be used by descriptors. + // TODO(amauryfa): Don't create the Factory from the DescriptorPool, but + // use the one passed while creating message classes. And remove this member. + PyMessageFactory* py_message_factory; // Cache the options for any kind of descriptor. // Descriptor pointers are owned by the DescriptorPool above. @@ -100,19 +90,6 @@ namespace cdescriptor_pool { const Descriptor* FindMessageTypeByName(PyDescriptorPool* self, const string& name); -// Registers a new Python class for the given message descriptor. -// On error, returns -1 with a Python exception set. -int RegisterMessageClass(PyDescriptorPool* self, - const Descriptor* message_descriptor, - CMessageClass* message_class); - -// Retrieves the Python class registered with the given message descriptor. -// -// Returns a *borrowed* reference if found, otherwise returns NULL with an -// exception set. -CMessageClass* GetMessageClass(PyDescriptorPool* self, - const Descriptor* message_descriptor); - // The functions below are also exposed as methods of the DescriptorPool type. // Looks up a message by name. Returns a PyMessageDescriptor corresponding to diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index 21bbb8c2..dbb7bca0 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -39,8 +39,8 @@ #include <google/protobuf/dynamic_message.h> #include <google/protobuf/message.h> #include <google/protobuf/pyext/descriptor.h> -#include <google/protobuf/pyext/descriptor_pool.h> #include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/message_factory.h> #include <google/protobuf/pyext/repeated_composite_container.h> #include <google/protobuf/pyext/repeated_scalar_container.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> @@ -60,35 +60,6 @@ PyObject* len(ExtensionDict* self) { #endif } -// TODO(tibell): Use VisitCompositeField. -int ReleaseExtension(ExtensionDict* self, - PyObject* extension, - const FieldDescriptor* descriptor) { - if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { - if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - if (repeated_composite_container::Release( - reinterpret_cast<RepeatedCompositeContainer*>( - extension)) < 0) { - return -1; - } - } else { - if (repeated_scalar_container::Release( - reinterpret_cast<RepeatedScalarContainer*>( - extension)) < 0) { - return -1; - } - } - } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - if (cmessage::ReleaseSubMessage( - self->parent, descriptor, - reinterpret_cast<CMessage*>(extension)) < 0) { - return -1; - } - } - - return 0; -} - PyObject* subscript(ExtensionDict* self, PyObject* key) { const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); if (descriptor == NULL) { @@ -130,8 +101,8 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - CMessageClass* message_class = cdescriptor_pool::GetMessageClass( - cmessage::GetDescriptorPoolForMessage(self->parent), + CMessageClass* message_class = message_factory::GetMessageClass( + cmessage::GetFactoryForMessage(self->parent), descriptor->message_type()); if (message_class == NULL) { return NULL; @@ -183,47 +154,6 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { return 0; } -PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) { - const FieldDescriptor* descriptor = - cmessage::GetExtensionDescriptor(extension); - if (descriptor == NULL) { - return NULL; - } - PyObject* value = PyDict_GetItem(self->values, extension); - if (self->parent) { - if (value != NULL) { - if (ReleaseExtension(self, value, descriptor) < 0) { - return NULL; - } - } - if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor( - self->parent, descriptor)) == NULL) { - return NULL; - } - } - if (PyDict_DelItem(self->values, extension) < 0) { - PyErr_Clear(); - } - Py_RETURN_NONE; -} - -PyObject* HasExtension(ExtensionDict* self, PyObject* extension) { - const FieldDescriptor* descriptor = - cmessage::GetExtensionDescriptor(extension); - if (descriptor == NULL) { - return NULL; - } - if (self->parent) { - return cmessage::HasFieldByDescriptor(self->parent, descriptor); - } else { - int exists = PyDict_Contains(self->values, extension); - if (exists < 0) { - return NULL; - } - return PyBool_FromLong(exists); - } -} - PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { ScopedPyObjectPtr extensions_by_name(PyObject_GetAttrString( reinterpret_cast<PyObject*>(self->parent), "_extensions_by_name")); @@ -282,8 +212,6 @@ static PyMappingMethods MpMethods = { #define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc } static PyMethodDef Methods[] = { - EDMETHOD(ClearExtension, METH_O, "Clears an extension from the object."), - EDMETHOD(HasExtension, METH_O, "Checks if the object has an extension."), EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."), EDMETHOD(_FindExtensionByNumber, METH_O, diff --git a/python/google/protobuf/pyext/extension_dict.h b/python/google/protobuf/pyext/extension_dict.h index 2456eda1..65b87862 100644 --- a/python/google/protobuf/pyext/extension_dict.h +++ b/python/google/protobuf/pyext/extension_dict.h @@ -86,49 +86,6 @@ namespace extension_dict { // Builds an Extensions dict for a specific message. ExtensionDict* NewExtensionDict(CMessage *parent); -// Gets the number of extension values in this ExtensionDict as a python object. -// -// Returns a new reference. -PyObject* len(ExtensionDict* self); - -// Releases extensions referenced outside this dictionary to keep outside -// references alive. -// -// Returns 0 on success, -1 on failure. -int ReleaseExtension(ExtensionDict* self, - PyObject* extension, - const FieldDescriptor* descriptor); - -// Gets an extension from the dict for the given extension descriptor. -// -// Returns a new reference. -PyObject* subscript(ExtensionDict* self, PyObject* key); - -// Assigns a value to an extension in the dict. Can only be used for singular -// simple types. -// -// Returns 0 on success, -1 on failure. -int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value); - -// Clears an extension from the dict. Will release the extension if there -// is still an external reference left to it. -// -// Returns None on success. -PyObject* ClearExtension(ExtensionDict* self, - PyObject* extension); - -// Gets an extension from the dict given the extension name as opposed to -// descriptor. -// -// Returns a new reference. -PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name); - -// Gets an extension from the dict given the extension field number as -// opposed to descriptor. -// -// Returns a new reference. -PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number); - } // namespace extension_dict } // namespace python } // namespace protobuf diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc index 0987b898..318c2e7c 100644 --- a/python/google/protobuf/pyext/map_container.cc +++ b/python/google/protobuf/pyext/map_container.cc @@ -42,7 +42,9 @@ #include <google/protobuf/map_field.h> #include <google/protobuf/map.h> #include <google/protobuf/message.h> +#include <google/protobuf/pyext/message_factory.h> #include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/repeated_composite_container.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #if PY_MAJOR_VERSION >= 3 @@ -328,6 +330,15 @@ PyObject* Clear(PyObject* _self) { Py_RETURN_NONE; } +PyObject* GetEntryClass(PyObject* _self) { + MapContainer* self = GetMap(_self); + CMessageClass* message_class = message_factory::GetMessageClass( + cmessage::GetFactoryForMessage(self->parent), + self->parent_field_descriptor->message_type()); + Py_XINCREF(message_class); + return reinterpret_cast<PyObject*>(message_class); +} + PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) { MapContainer* self = GetMap(_self); @@ -400,12 +411,7 @@ PyObject *NewScalarMapContainer( return NULL; } -#if PY_MAJOR_VERSION >= 3 - ScopedPyObjectPtr obj(PyType_GenericAlloc( - reinterpret_cast<PyTypeObject *>(ScalarMapContainer_Type), 0)); -#else - ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0)); -#endif + ScopedPyObjectPtr obj(PyType_GenericAlloc(ScalarMapContainer_Type, 0)); if (obj.get() == NULL) { return PyErr_Format(PyExc_RuntimeError, "Could not allocate new container."); @@ -527,6 +533,8 @@ static PyMethodDef ScalarMapMethods[] = { "Removes all elements from the map." }, { "get", ScalarMapGet, METH_VARARGS, "Gets the value for the given key if present, or otherwise a default" }, + { "GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS, + "Return the class used to build Entries of (key, value) pairs." }, /* { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, "Makes a deep copy of the class." }, @@ -536,6 +544,7 @@ static PyMethodDef ScalarMapMethods[] = { {NULL, NULL}, }; +PyTypeObject *ScalarMapContainer_Type; #if PY_MAJOR_VERSION >= 3 static PyType_Slot ScalarMapContainer_Type_slots[] = { {Py_tp_dealloc, (void *)ScalarMapDealloc}, @@ -554,7 +563,6 @@ static PyMethodDef ScalarMapMethods[] = { Py_TPFLAGS_DEFAULT, ScalarMapContainer_Type_slots }; - PyObject *ScalarMapContainer_Type; #else static PyMappingMethods ScalarMapMappingMethods = { MapReflectionFriend::Length, // mp_length @@ -562,7 +570,7 @@ static PyMethodDef ScalarMapMethods[] = { MapReflectionFriend::ScalarMapSetItem, // mp_ass_subscript }; - PyTypeObject ScalarMapContainer_Type = { + PyTypeObject _ScalarMapContainer_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME ".ScalarMapContainer", // tp_name sizeof(MapContainer), // tp_basicsize @@ -643,12 +651,7 @@ PyObject* NewMessageMapContainer( return NULL; } -#if PY_MAJOR_VERSION >= 3 - PyObject* obj = PyType_GenericAlloc( - reinterpret_cast<PyTypeObject *>(MessageMapContainer_Type), 0); -#else - PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0); -#endif + PyObject* obj = PyType_GenericAlloc(MessageMapContainer_Type, 0); if (obj == NULL) { return PyErr_Format(PyExc_RuntimeError, "Could not allocate new container."); @@ -780,6 +783,8 @@ static PyMethodDef MessageMapMethods[] = { "Gets the value for the given key if present, or otherwise a default" }, { "get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O, "Alias for getitem, useful to make explicit that the map is mutated." }, + { "GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS, + "Return the class used to build Entries of (key, value) pairs." }, /* { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, "Makes a deep copy of the class." }, @@ -789,6 +794,7 @@ static PyMethodDef MessageMapMethods[] = { {NULL, NULL}, }; +PyTypeObject *MessageMapContainer_Type; #if PY_MAJOR_VERSION >= 3 static PyType_Slot MessageMapContainer_Type_slots[] = { {Py_tp_dealloc, (void *)MessageMapDealloc}, @@ -807,8 +813,6 @@ static PyMethodDef MessageMapMethods[] = { Py_TPFLAGS_DEFAULT, MessageMapContainer_Type_slots }; - - PyObject *MessageMapContainer_Type; #else static PyMappingMethods MessageMapMappingMethods = { MapReflectionFriend::Length, // mp_length @@ -816,7 +820,7 @@ static PyMethodDef MessageMapMethods[] = { MapReflectionFriend::MessageMapSetItem, // mp_ass_subscript }; - PyTypeObject MessageMapContainer_Type = { + PyTypeObject _MessageMapContainer_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME ".MessageMapContainer", // tp_name sizeof(MessageMapContainer), // tp_basicsize @@ -965,6 +969,63 @@ PyTypeObject MapIterator_Type = { 0, // tp_init }; +bool InitMapContainers() { + // ScalarMapContainer_Type derives from our MutableMapping type. + ScopedPyObjectPtr containers(PyImport_ImportModule( + "google.protobuf.internal.containers")); + if (containers == NULL) { + return false; + } + + ScopedPyObjectPtr mutable_mapping( + PyObject_GetAttrString(containers.get(), "MutableMapping")); + if (mutable_mapping == NULL) { + return false; + } + + if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) { + return false; + } + + Py_INCREF(mutable_mapping.get()); +#if PY_MAJOR_VERSION >= 3 + PyObject* bases = PyTuple_New(1); + PyTuple_SET_ITEM(bases, 0, mutable_mapping.get()); + + ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>( + PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases)); +#else + _ScalarMapContainer_Type.tp_base = + reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); + + if (PyType_Ready(&_ScalarMapContainer_Type) < 0) { + return false; + } + + ScalarMapContainer_Type = &_ScalarMapContainer_Type; +#endif + + if (PyType_Ready(&MapIterator_Type) < 0) { + return false; + } + +#if PY_MAJOR_VERSION >= 3 + MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>( + PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases)); +#else + Py_INCREF(mutable_mapping.get()); + _MessageMapContainer_Type.tp_base = + reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); + + if (PyType_Ready(&_MessageMapContainer_Type) < 0) { + return false; + } + + MessageMapContainer_Type = &_MessageMapContainer_Type; +#endif + return true; +} + } // namespace python } // namespace protobuf } // namespace google diff --git a/python/google/protobuf/pyext/map_container.h b/python/google/protobuf/pyext/map_container.h index fbd6713f..615657b0 100644 --- a/python/google/protobuf/pyext/map_container.h +++ b/python/google/protobuf/pyext/map_container.h @@ -112,16 +112,10 @@ struct MessageMapContainer : public MapContainer { PyObject* message_dict; }; -#if PY_MAJOR_VERSION >= 3 - extern PyObject *MessageMapContainer_Type; - extern PyType_Spec MessageMapContainer_Type_spec; - extern PyObject *ScalarMapContainer_Type; - extern PyType_Spec ScalarMapContainer_Type_spec; -#else - extern PyTypeObject MessageMapContainer_Type; - extern PyTypeObject ScalarMapContainer_Type; -#endif +bool InitMapContainers(); +extern PyTypeObject* MessageMapContainer_Type; +extern PyTypeObject* ScalarMapContainer_Type; extern PyTypeObject MapIterator_Type; // Both map types use the same iterator. // Builds a MapContainer object, from a parent message and a diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 5535338d..1b325469 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -63,6 +63,7 @@ #include <google/protobuf/pyext/repeated_composite_container.h> #include <google/protobuf/pyext/repeated_scalar_container.h> #include <google/protobuf/pyext/map_container.h> +#include <google/protobuf/pyext/message_factory.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #include <google/protobuf/stubs/strutil.h> @@ -244,6 +245,12 @@ static PyObject* New(PyTypeObject* type, return NULL; } + // Messages have no __dict__ + ScopedPyObjectPtr slots(PyTuple_New(0)); + if (PyDict_SetItemString(dict, "__slots__", slots.get()) < 0) { + return NULL; + } + // Build the arguments to the base metaclass. // We change the __bases__ classes. ScopedPyObjectPtr new_args; @@ -300,16 +307,19 @@ static PyObject* New(PyTypeObject* type, newtype->message_descriptor = descriptor; // TODO(amauryfa): Don't always use the canonical pool of the descriptor, // use the MessageFactory optionally passed in the class dict. - newtype->py_descriptor_pool = GetDescriptorPool_FromPool( - descriptor->file()->pool()); - if (newtype->py_descriptor_pool == NULL) { + PyDescriptorPool* py_descriptor_pool = + GetDescriptorPool_FromPool(descriptor->file()->pool()); + if (py_descriptor_pool == NULL) { return NULL; } - Py_INCREF(newtype->py_descriptor_pool); + newtype->py_message_factory = py_descriptor_pool->py_message_factory; + Py_INCREF(newtype->py_message_factory); - // Add the message to the DescriptorPool. - if (cdescriptor_pool::RegisterMessageClass(newtype->py_descriptor_pool, - descriptor, newtype) < 0) { + // Register the message in the MessageFactory. + // TODO(amauryfa): Move this call to MessageFactory.GetPrototype() when the + // MessageFactory is fully implemented in C++. + if (message_factory::RegisterMessageClass(newtype->py_message_factory, + descriptor, newtype) < 0) { return NULL; } @@ -321,8 +331,8 @@ static PyObject* New(PyTypeObject* type, } static void Dealloc(CMessageClass *self) { - Py_DECREF(self->py_message_descriptor); - Py_DECREF(self->py_descriptor_pool); + Py_XDECREF(self->py_message_descriptor); + Py_XDECREF(self->py_message_factory); Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); } @@ -752,15 +762,9 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, namespace cmessage { -PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message) { - // No need to check the type: the type of instances of CMessage is always - // an instance of CMessageClass. Let's prove it with a debug-only check. +PyMessageFactory* GetFactoryForMessage(CMessage* message) { GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type)); - return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_descriptor_pool; -} - -MessageFactory* GetFactoryForMessage(CMessage* message) { - return GetDescriptorPoolForMessage(message)->message_factory; + return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory; } static int MaybeReleaseOverlappingOneofField( @@ -813,7 +817,8 @@ static Message* GetMutableMessage( return NULL; } return reflection->MutableMessage( - parent_message, parent_field, GetFactoryForMessage(parent)); + parent_message, parent_field, + GetFactoryForMessage(parent)->message_factory); } struct FixupMessageReference : public ChildVisitor { @@ -1172,6 +1177,8 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { } CMessage* cmessage = reinterpret_cast<CMessage*>(message.get()); if (PyDict_Check(value)) { + // Make the message exist even if the dict is empty. + AssureWritable(cmessage); if (InitAttributes(cmessage, NULL, value) < 0) { return -1; } @@ -1231,7 +1238,7 @@ static PyObject* New(PyTypeObject* cls, if (message_descriptor == NULL) { return NULL; } - const Message* default_message = type->py_descriptor_pool->message_factory + const Message* default_message = type->py_message_factory->message_factory ->GetPrototype(message_descriptor); if (default_message == NULL) { PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str()); @@ -1292,6 +1299,9 @@ struct ClearWeakReferences : public ChildVisitor { }; static void Dealloc(CMessage* self) { + if (self->weakreflist) { + PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self)); + } // Null out all weak references from children to this message. GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences())); if (self->extensions) { @@ -1459,18 +1469,20 @@ PyObject* HasField(CMessage* self, PyObject* arg) { } PyObject* ClearExtension(CMessage* self, PyObject* extension) { + const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); + if (descriptor == NULL) { + return NULL; + } if (self->extensions != NULL) { - return extension_dict::ClearExtension(self->extensions, extension); - } else { - const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); - if (descriptor == NULL) { - return NULL; - } - if (ScopedPyObjectPtr(ClearFieldByDescriptor(self, descriptor)) == NULL) { - return NULL; + PyObject* value = PyDict_GetItem(self->extensions->values, extension); + if (value != NULL) { + if (InternalReleaseFieldByDescriptor(self, descriptor, value) < 0) { + return NULL; + } + PyDict_DelItem(self->extensions->values, extension); } } - Py_RETURN_NONE; + return ClearFieldByDescriptor(self, descriptor); } PyObject* HasExtension(CMessage* self, PyObject* extension) { @@ -1556,7 +1568,7 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) { Message* ReleaseMessage(CMessage* self, const Descriptor* descriptor, const FieldDescriptor* field_descriptor) { - MessageFactory* message_factory = GetFactoryForMessage(self); + MessageFactory* message_factory = GetFactoryForMessage(self)->message_factory; Message* released_message = self->message->GetReflection()->ReleaseMessage( self->message, field_descriptor, message_factory); // ReleaseMessage will return NULL which differs from @@ -1624,12 +1636,19 @@ int InternalReleaseFieldByDescriptor( PyObject* ClearFieldByDescriptor( CMessage* self, - const FieldDescriptor* descriptor) { - if (!CheckFieldBelongsToMessage(descriptor, self->message)) { + const FieldDescriptor* field_descriptor) { + if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) { return NULL; } AssureWritable(self); - self->message->GetReflection()->ClearField(self->message, descriptor); + Message* message = self->message; + message->GetReflection()->ClearField(message, field_descriptor); + if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM && + !message->GetReflection()->SupportsUnknownEnumValues()) { + UnknownFieldSet* unknown_field_set = + message->GetReflection()->MutableUnknownFields(message); + unknown_field_set->DeleteByNumber(field_descriptor->number()); + } Py_RETURN_NONE; } @@ -1665,27 +1684,17 @@ PyObject* ClearField(CMessage* self, PyObject* arg) { arg = arg_in_oneof.get(); } - PyObject* composite_field = self->composite_fields ? - PyDict_GetItem(self->composite_fields, arg) : NULL; - - // Only release the field if there's a possibility that there are - // references to it. - if (composite_field != NULL) { - if (InternalReleaseFieldByDescriptor(self, field_descriptor, - composite_field) < 0) { - return NULL; + // Release the field if it exists in the dict of composite fields. + if (self->composite_fields) { + PyObject* value = PyDict_GetItem(self->composite_fields, arg); + if (value != NULL) { + if (InternalReleaseFieldByDescriptor(self, field_descriptor, value) < 0) { + return NULL; + } + PyDict_DelItem(self->composite_fields, arg); } - PyDict_DelItem(self->composite_fields, arg); - } - message->GetReflection()->ClearField(message, field_descriptor); - if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM && - !message->GetReflection()->SupportsUnknownEnumValues()) { - UnknownFieldSet* unknown_field_set = - message->GetReflection()->MutableUnknownFields(message); - unknown_field_set->DeleteByNumber(field_descriptor->number()); } - - Py_RETURN_NONE; + return ClearFieldByDescriptor(self, field_descriptor); } PyObject* Clear(CMessage* self) { @@ -1927,8 +1936,8 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { if (allow_oversize_protos) { input.SetTotalBytesLimit(INT_MAX, INT_MAX); } - PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); - input.SetExtensionRegistry(pool->pool, pool->message_factory); + PyMessageFactory* factory = GetFactoryForMessage(self); + input.SetExtensionRegistry(factory->pool->pool, factory->message_factory); bool success = self->message->MergePartialFromCodedStream(&input); if (success) { return PyInt_FromLong(input.CurrentPosition()); @@ -2108,8 +2117,8 @@ static PyObject* ListFields(CMessage* self) { // is no message class and we cannot retrieve the value. // TODO(amauryfa): consider building the class on the fly! if (fields[i]->message_type() != NULL && - cdescriptor_pool::GetMessageClass( - GetDescriptorPoolForMessage(self), + message_factory::GetMessageClass( + GetFactoryForMessage(self), fields[i]->message_type()) == NULL) { PyErr_Clear(); continue; @@ -2306,12 +2315,12 @@ PyObject* InternalGetScalar(const Message* message, PyObject* InternalGetSubMessage( CMessage* self, const FieldDescriptor* field_descriptor) { const Reflection* reflection = self->message->GetReflection(); - PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); + PyMessageFactory* factory = GetFactoryForMessage(self); const Message& sub_message = reflection->GetMessage( - *self->message, field_descriptor, pool->message_factory); + *self->message, field_descriptor, factory->message_factory); - CMessageClass* message_class = cdescriptor_pool::GetMessageClass( - pool, field_descriptor->message_type()); + CMessageClass* message_class = message_factory::GetMessageClass( + factory, field_descriptor->message_type()); if (message_class == NULL) { return NULL; } @@ -2656,8 +2665,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { const Descriptor* entry_type = field_descriptor->message_type(); const FieldDescriptor* value_type = entry_type->FindFieldByName("value"); if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - CMessageClass* value_class = cdescriptor_pool::GetMessageClass( - GetDescriptorPoolForMessage(self), value_type->message_type()); + CMessageClass* value_class = message_factory::GetMessageClass( + GetFactoryForMessage(self), value_type->message_type()); if (value_class == NULL) { return NULL; } @@ -2679,8 +2688,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { PyObject* py_container = NULL; if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - CMessageClass* message_class = cdescriptor_pool::GetMessageClass( - GetDescriptorPoolForMessage(self), field_descriptor->message_type()); + CMessageClass* message_class = message_factory::GetMessageClass( + GetFactoryForMessage(self), field_descriptor->message_type()); if (message_class == NULL) { return NULL; } @@ -2775,7 +2784,7 @@ PyTypeObject CMessage_Type = { 0, // tp_traverse 0, // tp_clear (richcmpfunc)cmessage::RichCompare, // tp_richcompare - 0, // tp_weaklistoffset + offsetof(CMessage, weakreflist), // tp_weaklistoffset 0, // tp_iter 0, // tp_iternext cmessage::Methods, // tp_methods @@ -2863,6 +2872,11 @@ bool InitProto2MessageModule(PyObject *m) { return false; } + // Initialize types and globals in message_factory.cc + if (!InitMessageFactory()) { + return false; + } + // Initialize constants defined in this file. InitGlobals(); @@ -2944,69 +2958,15 @@ bool InitProto2MessageModule(PyObject *m) { } // Initialize Map container types. - { - // ScalarMapContainer_Type derives from our MutableMapping type. - ScopedPyObjectPtr containers(PyImport_ImportModule( - "google.protobuf.internal.containers")); - if (containers == NULL) { - return false; - } - - ScopedPyObjectPtr mutable_mapping( - PyObject_GetAttrString(containers.get(), "MutableMapping")); - if (mutable_mapping == NULL) { - return false; - } - - if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) { - return false; - } - - Py_INCREF(mutable_mapping.get()); -#if PY_MAJOR_VERSION >= 3 - PyObject* bases = PyTuple_New(1); - PyTuple_SET_ITEM(bases, 0, mutable_mapping.get()); - - ScalarMapContainer_Type = - PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases); - PyModule_AddObject(m, "ScalarMapContainer", ScalarMapContainer_Type); -#else - ScalarMapContainer_Type.tp_base = - reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); - - if (PyType_Ready(&ScalarMapContainer_Type) < 0) { - return false; - } - - PyModule_AddObject(m, "ScalarMapContainer", - reinterpret_cast<PyObject*>(&ScalarMapContainer_Type)); -#endif - - if (PyType_Ready(&MapIterator_Type) < 0) { - return false; - } - - PyModule_AddObject(m, "MapIterator", - reinterpret_cast<PyObject*>(&MapIterator_Type)); - - -#if PY_MAJOR_VERSION >= 3 - MessageMapContainer_Type = - PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases); - PyModule_AddObject(m, "MessageMapContainer", MessageMapContainer_Type); -#else - Py_INCREF(mutable_mapping.get()); - MessageMapContainer_Type.tp_base = - reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); - - if (PyType_Ready(&MessageMapContainer_Type) < 0) { - return false; - } - - PyModule_AddObject(m, "MessageMapContainer", - reinterpret_cast<PyObject*>(&MessageMapContainer_Type)); -#endif + if (!InitMapContainers()) { + return false; } + PyModule_AddObject(m, "ScalarMapContainer", + reinterpret_cast<PyObject*>(ScalarMapContainer_Type)); + PyModule_AddObject(m, "MessageMapContainer", + reinterpret_cast<PyObject*>(MessageMapContainer_Type)); + PyModule_AddObject(m, "MapIterator", + reinterpret_cast<PyObject*>(&MapIterator_Type)); if (PyType_Ready(&ExtensionDict_Type) < 0) { return false; diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h index c44a2ae2..1550724c 100644 --- a/python/google/protobuf/pyext/message.h +++ b/python/google/protobuf/pyext/message.h @@ -62,7 +62,7 @@ using internal::shared_ptr; namespace python { struct ExtensionDict; -struct PyDescriptorPool; +struct PyMessageFactory; typedef struct CMessage { PyObject_HEAD; @@ -112,6 +112,9 @@ typedef struct CMessage { // Similar to composite_fields, acting as a cache, but also contains the // required extension dict logic. ExtensionDict* extensions; + + // Implements the "weakref" protocol for this object. + PyObject* weakreflist; } CMessage; extern PyTypeObject CMessage_Type; @@ -132,14 +135,11 @@ struct CMessageClass { // Owned reference, used to keep the pointer above alive. PyObject* py_message_descriptor; - // The Python DescriptorPool used to create the class. It is needed to resolve + // The Python MessageFactory used to create the class. It is needed to resolve // fields descriptors, including extensions fields; its C++ MessageFactory is // used to instantiate submessages. - // This can be different from DESCRIPTOR.file.pool, in the case of a custom - // DescriptorPool which defines new extensions. - // We own the reference, because it's important to keep the descriptors and - // factory alive. - PyDescriptorPool* py_descriptor_pool; + // We own the reference, because it's important to keep the factory alive. + PyMessageFactory* py_message_factory; PyObject* AsPyObject() { return reinterpret_cast<PyObject*>(this); @@ -154,14 +154,6 @@ namespace cmessage { // The caller must fill self->message, self->owner and eventually self->parent. CMessage* NewEmptyMessage(CMessageClass* type); -// Release a submessage from its proto tree, making it a new top-level messgae. -// A new message will be created if this is a read-only default instance. -// -// Corresponds to reflection api method ReleaseMessage. -int ReleaseSubMessage(CMessage* self, - const FieldDescriptor* field_descriptor, - CMessage* child_cmessage); - // Retrieves the C++ descriptor of a Python Extension descriptor. // On error, return NULL with an exception set. const FieldDescriptor* GetExtensionDescriptor(PyObject* extension); @@ -262,14 +254,13 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner); int AssureWritable(CMessage* self); -// Returns the "best" DescriptorPool for the given message. -// This is often equivalent to message.DESCRIPTOR.pool, but not always, when -// the message class was created from a MessageFactory using a custom pool which -// uses the generated pool as an underlay. +// Returns the message factory for the given message. +// This is equivalent to message.MESSAGE_FACTORY // -// The returned pool is suitable for finding fields and building submessages, +// The returned factory is suitable for finding fields and building submessages, // even in the case of extensions. -PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message); +// Returns a *borrowed* reference, and never fails because we pass a CMessage. +PyMessageFactory* GetFactoryForMessage(CMessage* message); PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg); diff --git a/python/google/protobuf/pyext/message_factory.cc b/python/google/protobuf/pyext/message_factory.cc new file mode 100644 index 00000000..2ad89022 --- /dev/null +++ b/python/google/protobuf/pyext/message_factory.cc @@ -0,0 +1,214 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include <Python.h> + +#include <google/protobuf/dynamic_message.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/message_factory.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +#if PY_MAJOR_VERSION >= 3 + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #endif + #define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob)? \ + ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ + PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif + +namespace google { +namespace protobuf { +namespace python { + +namespace message_factory { + +PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) { + PyMessageFactory* factory = reinterpret_cast<PyMessageFactory*>( + PyType_GenericAlloc(type, 0)); + if (factory == NULL) { + return NULL; + } + + DynamicMessageFactory* message_factory = new DynamicMessageFactory(); + // This option might be the default some day. + message_factory->SetDelegateToGeneratedFactory(true); + factory->message_factory = message_factory; + + factory->pool = pool; + // TODO(amauryfa): When the MessageFactory is not created from the + // DescriptorPool this reference should be owned, not borrowed. + // Py_INCREF(pool); + + factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap(); + + return factory; +} + +PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) { + static char* kwlist[] = {"pool", 0}; + PyObject* pool = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, &pool)) { + return NULL; + } + ScopedPyObjectPtr owned_pool; + if (pool == NULL || pool == Py_None) { + owned_pool.reset(PyObject_CallFunction( + reinterpret_cast<PyObject*>(&PyDescriptorPool_Type), NULL)); + if (owned_pool == NULL) { + return NULL; + } + pool = owned_pool.get(); + } else { + if (!PyObject_TypeCheck(pool, &PyDescriptorPool_Type)) { + PyErr_Format(PyExc_TypeError, "Expected a DescriptorPool, got %s", + pool->ob_type->tp_name); + return NULL; + } + } + + return reinterpret_cast<PyObject*>( + NewMessageFactory(type, reinterpret_cast<PyDescriptorPool*>(pool))); +} + +static void Dealloc(PyMessageFactory* self) { + // TODO(amauryfa): When the MessageFactory is not created from the + // DescriptorPool this reference should be owned, not borrowed. + // Py_CLEAR(self->pool); + typedef PyMessageFactory::ClassesByMessageMap::iterator iterator; + for (iterator it = self->classes_by_descriptor->begin(); + it != self->classes_by_descriptor->end(); ++it) { + Py_DECREF(it->second); + } + delete self->classes_by_descriptor; + delete self->message_factory; + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +} + +// Add a message class to our database. +int RegisterMessageClass(PyMessageFactory* self, + const Descriptor* message_descriptor, + CMessageClass* message_class) { + Py_INCREF(message_class); + typedef PyMessageFactory::ClassesByMessageMap::iterator iterator; + std::pair<iterator, bool> ret = self->classes_by_descriptor->insert( + std::make_pair(message_descriptor, message_class)); + if (!ret.second) { + // Update case: DECREF the previous value. + Py_DECREF(ret.first->second); + ret.first->second = message_class; + } + return 0; +} + +// Retrieve the message class added to our database. +CMessageClass* GetMessageClass(PyMessageFactory* self, + const Descriptor* message_descriptor) { + typedef PyMessageFactory::ClassesByMessageMap::iterator iterator; + iterator ret = self->classes_by_descriptor->find(message_descriptor); + if (ret == self->classes_by_descriptor->end()) { + PyErr_Format(PyExc_TypeError, "No message class registered for '%s'", + message_descriptor->full_name().c_str()); + return NULL; + } else { + return ret->second; + } +} + +static PyMethodDef Methods[] = { + {NULL}}; + +static PyObject* GetPool(PyMessageFactory* self, void* closure) { + Py_INCREF(self->pool); + return reinterpret_cast<PyObject*>(self->pool); +} + +static PyGetSetDef Getters[] = { + {"pool", (getter)GetPool, NULL, "DescriptorPool"}, + {NULL} +}; + +} // namespace message_factory + +PyTypeObject PyMessageFactory_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME + ".MessageFactory", // tp_name + sizeof(PyMessageFactory), // tp_basicsize + 0, // tp_itemsize + (destructor)message_factory::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags + "A static Message Factory", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + message_factory::Methods, // tp_methods + 0, // tp_members + message_factory::Getters, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + message_factory::New, // tp_new + PyObject_Del, // tp_free +}; + +bool InitMessageFactory() { + if (PyType_Ready(&PyMessageFactory_Type) < 0) { + return false; + } + + return true; +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/message_factory.h b/python/google/protobuf/pyext/message_factory.h new file mode 100644 index 00000000..07cccbfb --- /dev/null +++ b/python/google/protobuf/pyext/message_factory.h @@ -0,0 +1,103 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__ + +#include <Python.h> + +#include <google/protobuf/stubs/hash.h> +#include <google/protobuf/descriptor.h> +#include <google/protobuf/pyext/descriptor_pool.h> + +namespace google { +namespace protobuf { +class MessageFactory; + +namespace python { + +// The (meta) type of all Messages classes. +struct CMessageClass; + +struct PyMessageFactory { + PyObject_HEAD + + // DynamicMessageFactory used to create C++ instances of messages. + // This object cache the descriptors that were used, so the DescriptorPool + // needs to get rid of it before it can delete itself. + // + // Note: A C++ MessageFactory is different from the PyMessageFactory. + // The C++ one creates messages, when the Python one creates classes. + MessageFactory* message_factory; + + // borrowed reference to a Python DescriptorPool. + // TODO(amauryfa): invert the dependency: the MessageFactory owns the + // DescriptorPool, not the opposite. + PyDescriptorPool* pool; + + // Make our own mapping to retrieve Python classes from C++ descriptors. + // + // Descriptor pointers stored here are owned by the DescriptorPool above. + // Python references to classes are owned by this PyDescriptorPool. + typedef hash_map<const Descriptor*, CMessageClass*> ClassesByMessageMap; + ClassesByMessageMap* classes_by_descriptor; +}; + +extern PyTypeObject PyMessageFactory_Type; + +namespace message_factory { + +// Creates a new MessageFactory instance. +PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool); + +// Registers a new Python class for the given message descriptor. +// On error, returns -1 with a Python exception set. +int RegisterMessageClass(PyMessageFactory* self, + const Descriptor* message_descriptor, + CMessageClass* message_class); + +// Retrieves the Python class registered with the given message descriptor. +// +// Returns a *borrowed* reference if found, otherwise returns NULL with an +// exception set. +CMessageClass* GetMessageClass(PyMessageFactory* self, + const Descriptor* message_descriptor); + +} // namespace message_factory + +// Initialize objects used by this module. +// On error, returns false with a Python exception set. +bool InitMessageFactory(); + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__ diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc index bb2f6db2..43a2bc12 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.cc +++ b/python/google/protobuf/pyext/repeated_composite_container.cc @@ -364,7 +364,7 @@ static int SortPythonMessages(RepeatedCompositeContainer* self, ScopedPyObjectPtr m(PyObject_GetAttrString(self->child_messages, "sort")); if (m == NULL) return -1; - if (PyObject_Call(m.get(), args, kwds) == NULL) + if (ScopedPyObjectPtr(PyObject_Call(m.get(), args, kwds)) == NULL) return -1; if (self->message != NULL) { ReorderAttached(self); diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py index aa466abd..ecbef211 100644 --- a/python/google/protobuf/symbol_database.py +++ b/python/google/protobuf/symbol_database.py @@ -129,7 +129,8 @@ class SymbolDatabase(message_factory.MessageFactory): Only messages already created and registered will be returned; (this is the case for imported _pb2 modules) - But unlike MessageFactory, this version also returns nested messages. + But unlike MessageFactory, this version also returns already defined nested + messages, but does not register any message extensions. Args: files: The file names to extract messages from. diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index 06b79d77..90f6ce42 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -228,13 +228,13 @@ def _BuildMessageFromTypeName(type_name, descriptor_pool): wasn't found matching type_name. """ # pylint: disable=g-import-not-at-top - from google.protobuf import message_factory - factory = message_factory.MessageFactory(descriptor_pool) + from google.protobuf import symbol_database + database = symbol_database.Default() try: message_descriptor = descriptor_pool.FindMessageTypeByName(type_name) except KeyError: return None - message_type = factory.GetPrototype(message_descriptor) + message_type = database.GetPrototype(message_descriptor) return message_type() @@ -317,8 +317,7 @@ class _Printer(object): # of this file to work around. # # TODO(haberman): refactor and optimize if this becomes an issue. - entry_submsg = field.message_type._concrete_class(key=key, - value=value[key]) + entry_submsg = value.GetEntryClass()(key=key, value=value[key]) self.PrintField(field, entry_submsg) elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: for element in value: @@ -749,8 +748,7 @@ class _Parser(object): if field.is_extension: sub_message = message.Extensions[field].add() elif is_map_entry: - # pylint: disable=protected-access - sub_message = field.message_type._concrete_class() + sub_message = getattr(message, field.name).GetEntryClass()() else: sub_message = getattr(message, field.name).add() else: @@ -1448,9 +1446,9 @@ def ParseBool(text): Raises: ValueError: If text is not a valid boolean. """ - if text in ('true', 't', '1'): + if text in ('true', 't', '1', 'True'): return True - elif text in ('false', 'f', '0'): + elif text in ('false', 'f', '0', 'False'): return False else: raise ValueError('Expected "true" or "false".') |