From bde4a3254a7de58911941b0fbf38e9dd992de973 Mon Sep 17 00:00:00 2001 From: "jieluo@google.com" Date: Tue, 12 Aug 2014 21:10:30 +0000 Subject: down integrate python opensource to svn --- Makefile.am | 39 +- python/README.txt | 2 + python/ez_setup.py | 2 + python/google/protobuf/descriptor.py | 178 +- python/google/protobuf/descriptor_database.py | 19 +- python/google/protobuf/descriptor_pool.py | 394 +-- .../google/protobuf/internal/api_implementation.cc | 139 ++ .../google/protobuf/internal/api_implementation.py | 52 +- .../internal/api_implementation_default_test.py | 63 + python/google/protobuf/internal/containers.py | 22 +- python/google/protobuf/internal/cpp_message.py | 2 +- python/google/protobuf/internal/decoder.py | 157 +- .../protobuf/internal/descriptor_cpp2_test.py | 58 + .../protobuf/internal/descriptor_database_test.py | 16 +- .../protobuf/internal/descriptor_pool_test.py | 428 +++- .../protobuf/internal/descriptor_pool_test1.proto | 94 + .../protobuf/internal/descriptor_pool_test2.proto | 70 + .../protobuf/internal/descriptor_python_test.py | 54 + python/google/protobuf/internal/descriptor_test.py | 152 +- python/google/protobuf/internal/encoder.py | 43 +- .../google/protobuf/internal/factory_test1.proto | 2 + .../google/protobuf/internal/factory_test2.proto | 15 + python/google/protobuf/internal/generator_test.py | 82 +- .../google/protobuf/internal/message_cpp_test.py | 45 - .../protobuf/internal/message_factory_cpp2_test.py | 56 + .../internal/message_factory_python_test.py | 54 + .../protobuf/internal/message_factory_test.py | 52 +- .../protobuf/internal/message_python_test.py | 54 + python/google/protobuf/internal/message_test.py | 274 ++- .../protobuf/internal/missing_enum_values.proto | 50 + python/google/protobuf/internal/python_message.py | 151 +- .../internal/reflection_cpp2_generated_test.py | 94 + .../internal/reflection_cpp_generated_test.py | 91 - python/google/protobuf/internal/reflection_test.py | 419 +++- .../protobuf/internal/service_reflection_test.py | 6 +- .../protobuf/internal/symbol_database_test.py | 120 + python/google/protobuf/internal/test_util.py | 95 +- .../google/protobuf/internal/text_encoding_test.py | 68 + .../google/protobuf/internal/text_format_test.py | 417 ++-- python/google/protobuf/internal/type_checkers.py | 72 +- .../protobuf/internal/unknown_fields_test.py | 71 +- .../google/protobuf/internal/wire_format_test.py | 8 +- python/google/protobuf/message.py | 6 +- python/google/protobuf/message_factory.py | 110 +- python/google/protobuf/pyext/README | 6 + python/google/protobuf/pyext/cpp_message.py | 61 + python/google/protobuf/pyext/descriptor.cc | 357 +++ python/google/protobuf/pyext/descriptor.h | 96 + python/google/protobuf/pyext/extension_dict.cc | 338 +++ python/google/protobuf/pyext/extension_dict.h | 123 + python/google/protobuf/pyext/message.cc | 2561 ++++++++++++++++++++ python/google/protobuf/pyext/message.h | 305 +++ python/google/protobuf/pyext/proto2_api_test.proto | 38 + python/google/protobuf/pyext/python.proto | 66 + python/google/protobuf/pyext/python_protobuf.h | 57 + .../protobuf/pyext/repeated_composite_container.cc | 763 ++++++ .../protobuf/pyext/repeated_composite_container.h | 172 ++ .../protobuf/pyext/repeated_scalar_container.cc | 825 +++++++ .../protobuf/pyext/repeated_scalar_container.h | 112 + python/google/protobuf/pyext/scoped_pyobject_ptr.h | 95 + python/google/protobuf/reflection.py | 46 +- python/google/protobuf/symbol_database.py | 185 ++ python/google/protobuf/text_encoding.py | 110 + python/google/protobuf/text_format.py | 322 ++- python/setup.py | 149 +- src/Makefile.am | 4 + .../protobuf/compiler/python/python_generator.cc | 34 +- ...rmat_unittest_data_pointy_oneof_implemented.txt | 129 + 68 files changed, 10255 insertions(+), 1095 deletions(-) create mode 100644 python/google/protobuf/internal/api_implementation.cc create mode 100644 python/google/protobuf/internal/api_implementation_default_test.py create mode 100644 python/google/protobuf/internal/descriptor_cpp2_test.py create mode 100644 python/google/protobuf/internal/descriptor_pool_test1.proto create mode 100644 python/google/protobuf/internal/descriptor_pool_test2.proto create mode 100644 python/google/protobuf/internal/descriptor_python_test.py delete mode 100644 python/google/protobuf/internal/message_cpp_test.py create mode 100644 python/google/protobuf/internal/message_factory_cpp2_test.py create mode 100644 python/google/protobuf/internal/message_factory_python_test.py create mode 100644 python/google/protobuf/internal/message_python_test.py create mode 100644 python/google/protobuf/internal/missing_enum_values.proto create mode 100755 python/google/protobuf/internal/reflection_cpp2_generated_test.py delete mode 100755 python/google/protobuf/internal/reflection_cpp_generated_test.py create mode 100644 python/google/protobuf/internal/symbol_database_test.py create mode 100755 python/google/protobuf/internal/text_encoding_test.py create mode 100644 python/google/protobuf/pyext/README create mode 100644 python/google/protobuf/pyext/cpp_message.py create mode 100644 python/google/protobuf/pyext/descriptor.cc create mode 100644 python/google/protobuf/pyext/descriptor.h create mode 100644 python/google/protobuf/pyext/extension_dict.cc create mode 100644 python/google/protobuf/pyext/extension_dict.h create mode 100644 python/google/protobuf/pyext/message.cc create mode 100644 python/google/protobuf/pyext/message.h create mode 100644 python/google/protobuf/pyext/proto2_api_test.proto create mode 100644 python/google/protobuf/pyext/python.proto create mode 100644 python/google/protobuf/pyext/python_protobuf.h create mode 100644 python/google/protobuf/pyext/repeated_composite_container.cc create mode 100644 python/google/protobuf/pyext/repeated_composite_container.h create mode 100644 python/google/protobuf/pyext/repeated_scalar_container.cc create mode 100644 python/google/protobuf/pyext/repeated_scalar_container.h create mode 100644 python/google/protobuf/pyext/scoped_pyobject_ptr.h create mode 100644 python/google/protobuf/symbol_database.py create mode 100644 python/google/protobuf/text_encoding.py create mode 100644 src/google/protobuf/testdata/text_format_unittest_data_pointy_oneof_implemented.txt diff --git a/Makefile.am b/Makefile.am index 5246e548..ad8a4df9 100644 --- a/Makefile.am +++ b/Makefile.am @@ -168,37 +168,64 @@ EXTRA_DIST = \ java/src/test/java/com/google/protobuf/test_custom_options.proto \ java/pom.xml \ java/README.txt \ - python/google/protobuf/internal/generator_test.py \ + python/google/protobuf/internal/api_implementation.cc \ + python/google/protobuf/internal/api_implementation.py \ + python/google/protobuf/internal/api_implementation_default_test.py \ python/google/protobuf/internal/containers.py \ + python/google/protobuf/internal/cpp_message.py \ python/google/protobuf/internal/decoder.py \ + python/google/protobuf/internal/descriptor_cpp2_test.py \ python/google/protobuf/internal/descriptor_database_test.py \ python/google/protobuf/internal/descriptor_pool_test.py \ + python/google/protobuf/internal/descriptor_pool_test1.proto \ + python/google/protobuf/internal/descriptor_pool_test2.proto \ + python/google/protobuf/internal/descriptor_python_test.py \ python/google/protobuf/internal/descriptor_test.py \ python/google/protobuf/internal/encoder.py \ python/google/protobuf/internal/enum_type_wrapper.py \ python/google/protobuf/internal/factory_test1.proto \ python/google/protobuf/internal/factory_test2.proto \ - python/google/protobuf/internal/message_cpp_test.py \ + python/google/protobuf/internal/generator_test.py \ + python/google/protobuf/internal/message_factory_cpp2_test.py \ + python/google/protobuf/internal/message_factory_python_test.py \ python/google/protobuf/internal/message_factory_test.py \ python/google/protobuf/internal/message_listener.py \ + python/google/protobuf/internal/message_python_test.py \ python/google/protobuf/internal/message_test.py \ + python/google/protobuf/internal/missing_enum_values.proto \ python/google/protobuf/internal/more_extensions.proto \ python/google/protobuf/internal/more_extensions_dynamic.proto \ python/google/protobuf/internal/more_messages.proto \ python/google/protobuf/internal/python_message.py \ - python/google/protobuf/internal/cpp_message.py \ - python/google/protobuf/internal/api_implementation.py \ + python/google/protobuf/internal/reflection_cpp2_generated_test.py \ python/google/protobuf/internal/reflection_test.py \ - python/google/protobuf/internal/reflection_cpp_generated_test.py \ python/google/protobuf/internal/service_reflection_test.py \ python/google/protobuf/internal/test_bad_identifiers.proto \ python/google/protobuf/internal/test_util.py \ + python/google/protobuf/internal/text_encoding_test.py \ python/google/protobuf/internal/text_format_test.py \ python/google/protobuf/internal/type_checkers.py \ python/google/protobuf/internal/unknown_fields_test.py \ python/google/protobuf/internal/wire_format.py \ python/google/protobuf/internal/wire_format_test.py \ python/google/protobuf/internal/__init__.py \ + python/google/protobuf/pyext/README \ + python/google/protobuf/pyext/cpp_message.py \ + python/google/protobuf/pyext/descriptor.h \ + python/google/protobuf/pyext/descriptor.cc \ + python/google/protobuf/pyext/extension_dict.h \ + python/google/protobuf/pyext/extension_dict.cc \ + python/google/protobuf/pyext/message.h \ + python/google/protobuf/pyext/message.cc \ + python/google/protobuf/pyext/proto2_api_test.proto \ + python/google/protobuf/pyext/python.proto \ + python/google/protobuf/pyext/python_protobuf.h \ + python/google/protobuf/pyext/repeated_composite_container.h \ + python/google/protobuf/pyext/repeated_composite_container.cc \ + python/google/protobuf/pyext/repeated_scalar_container.h \ + python/google/protobuf/pyext/repeated_scalar_container.cc \ + python/google/protobuf/pyext/scoped_pyobject_ptr.h \ + python/google/protobuf/pyext/__init__.py \ python/google/protobuf/descriptor.py \ python/google/protobuf/descriptor_database.py \ python/google/protobuf/descriptor_pool.py \ @@ -207,6 +234,8 @@ EXTRA_DIST = \ python/google/protobuf/reflection.py \ python/google/protobuf/service.py \ python/google/protobuf/service_reflection.py \ + python/google/protobuf/symbol_database.py \ + python/google/protobuf/text_encoding.py \ python/google/protobuf/text_format.py \ python/google/protobuf/__init__.py \ python/google/__init__.py \ diff --git a/python/README.txt b/python/README.txt index bb0032fe..e052e895 100644 --- a/python/README.txt +++ b/python/README.txt @@ -47,6 +47,7 @@ Installation $ python setup.py build $ python setup.py test + $ python setup.py exttest # To test C++ implementation (see below). If some tests fail, this library may not work correctly on your system. Continue at your own risk. @@ -90,6 +91,7 @@ To use the C++ implementation, you need to: 2) Export an environment variable: $ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp + $ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2 You need to export this variable before running setup.py script to build and install the extension. You must also set the variable at runtime, otherwise diff --git a/python/ez_setup.py b/python/ez_setup.py index a2cf777d..71177b00 100755 --- a/python/ez_setup.py +++ b/python/ez_setup.py @@ -103,10 +103,12 @@ def use_setuptools( sys.path.insert(0, egg) import setuptools; setuptools.bootstrap_install_from = egg try: + return do_download() import pkg_resources except ImportError: return do_download() try: + return do_download() pkg_resources.require("setuptools>="+version); return except pkg_resources.VersionConflict, e: if was_imported: diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index eb13eda5..555498d5 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -28,19 +28,26 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# Needs to stay compatible with Python 2.5 due to GAE. +# +# Copyright 2007 Google Inc. All Rights Reserved. + """Descriptors essentially contain exactly the information found in a .proto file, in types that make this information accessible in Python. """ __author__ = 'robinson@google.com (Will Robinson)' - from google.protobuf.internal import api_implementation if api_implementation.Type() == 'cpp': + # Used by MakeDescriptor in cpp mode + import os + import uuid + if api_implementation.Version() == 2: - from google.protobuf.internal.cpp import _message + from google.protobuf.pyext import _message else: from google.protobuf.internal import cpp_message @@ -220,13 +227,21 @@ class Descriptor(_NestedDescriptorBase): options: (descriptor_pb2.MessageOptions) Protocol message options or None to use default message options. + oneofs: (list of OneofDescriptor) The list of descriptors for oneof fields + in this message. + oneofs_by_name: (dict str -> OneofDescriptor) Same objects as in |oneofs|, + but indexed by "name" attribute. + file: (FileDescriptor) Reference to file descriptor. """ + # NOTE(tmarek): The file argument redefining a builtin is nothing we can + # fix right now since we don't know how many clients already rely on the + # name of the argument. def __init__(self, name, full_name, filename, containing_type, fields, nested_types, enum_types, extensions, options=None, - is_extendable=True, extension_ranges=None, file=None, - serialized_start=None, serialized_end=None): + is_extendable=True, extension_ranges=None, oneofs=None, + file=None, serialized_start=None, serialized_end=None): # pylint:disable=redefined-builtin """Arguments to __init__() are as described in the description of Descriptor fields above. @@ -236,7 +251,7 @@ class Descriptor(_NestedDescriptorBase): super(Descriptor, self).__init__( options, 'MessageOptions', name, full_name, file, containing_type, serialized_start=serialized_start, - serialized_end=serialized_start) + serialized_end=serialized_end) # We have fields in addition to fields_by_name and fields_by_number, # so that: @@ -250,6 +265,8 @@ class Descriptor(_NestedDescriptorBase): self.fields_by_name = dict((f.name, f) for f in fields) self.nested_types = nested_types + for nested_type in nested_types: + nested_type.containing_type = self self.nested_types_by_name = dict((t.name, t) for t in nested_types) self.enum_types = enum_types @@ -265,9 +282,10 @@ class Descriptor(_NestedDescriptorBase): self.extensions_by_name = dict((f.name, f) for f in extensions) self.is_extendable = is_extendable self.extension_ranges = extension_ranges - - self._serialized_start = serialized_start - self._serialized_end = serialized_end + self.oneofs = oneofs if oneofs is not None else [] + self.oneofs_by_name = dict((o.name, o) for o in self.oneofs) + for oneof in self.oneofs: + oneof.containing_type = self def EnumValueName(self, enum, value): """Returns the string name of an enum value. @@ -353,6 +371,9 @@ class FieldDescriptor(DescriptorBase): options: (descriptor_pb2.FieldOptions) Protocol message field options or None to use default field options. + + containing_oneof: (OneofDescriptor) If the field is a member of a oneof + union, contains its descriptor. Otherwise, None. """ # Must be consistent with C++ FieldDescriptor::Type enum in @@ -425,10 +446,16 @@ class FieldDescriptor(DescriptorBase): LABEL_REPEATED = 3 MAX_LABEL = 3 + # Must be consistent with C++ constants kMaxNumber, kFirstReservedNumber, + # and kLastReservedNumber in descriptor.h + MAX_FIELD_NUMBER = (1 << 29) - 1 + FIRST_RESERVED_FIELD_NUMBER = 19000 + LAST_RESERVED_FIELD_NUMBER = 19999 + def __init__(self, name, full_name, index, number, type, cpp_type, label, default_value, message_type, enum_type, containing_type, is_extension, extension_scope, options=None, - has_default_value=True): + has_default_value=True, containing_oneof=None): """The arguments are as described in the description of FieldDescriptor attributes above. @@ -451,15 +478,21 @@ class FieldDescriptor(DescriptorBase): self.enum_type = enum_type self.is_extension = is_extension self.extension_scope = extension_scope + self.containing_oneof = containing_oneof if api_implementation.Type() == 'cpp': if is_extension: if api_implementation.Version() == 2: - self._cdescriptor = _message.GetExtensionDescriptor(full_name) + # pylint: disable=protected-access + self._cdescriptor = ( + _message.Message._GetExtensionDescriptor(full_name)) + # pylint: enable=protected-access else: self._cdescriptor = cpp_message.GetExtensionDescriptor(full_name) else: if api_implementation.Version() == 2: - self._cdescriptor = _message.GetFieldDescriptor(full_name) + # pylint: disable=protected-access + self._cdescriptor = _message.Message._GetFieldDescriptor(full_name) + # pylint: enable=protected-access else: self._cdescriptor = cpp_message.GetFieldDescriptor(full_name) else: @@ -522,7 +555,7 @@ class EnumDescriptor(_NestedDescriptorBase): super(EnumDescriptor, self).__init__( options, 'EnumOptions', name, full_name, file, containing_type, serialized_start=serialized_start, - serialized_end=serialized_start) + serialized_end=serialized_end) self.values = values for value in self.values: @@ -530,9 +563,6 @@ class EnumDescriptor(_NestedDescriptorBase): self.values_by_name = dict((v.name, v) for v in values) self.values_by_number = dict((v.number, v) for v in values) - self._serialized_start = serialized_start - self._serialized_end = serialized_end - def CopyToProto(self, proto): """Copies this to a descriptor_pb2.EnumDescriptorProto. @@ -567,6 +597,29 @@ class EnumValueDescriptor(DescriptorBase): self.type = type +class OneofDescriptor(object): + """Descriptor for a oneof field. + + name: (str) Name of the oneof field. + full_name: (str) Full name of the oneof field, including package name. + index: (int) 0-based index giving the order of the oneof field inside + its containing type. + containing_type: (Descriptor) Descriptor of the protocol message + type that contains this field. Set by the Descriptor constructor + if we're passed into one. + fields: (list of FieldDescriptor) The list of field descriptors this + oneof can contain. + """ + + def __init__(self, name, full_name, index, containing_type, fields): + """Arguments are as described in the attribute description above.""" + self.name = name + self.full_name = full_name + self.index = index + self.containing_type = containing_type + self.fields = fields + + class ServiceDescriptor(_NestedDescriptorBase): """Descriptor for a service. @@ -645,13 +698,22 @@ class MethodDescriptor(DescriptorBase): class FileDescriptor(DescriptorBase): """Descriptor for a file. Mimics the descriptor_pb2.FileDescriptorProto. + Note that enum_types_by_name, extensions_by_name, and dependencies + fields are only set by the message_factory module, and not by the + generated proto code. + name: name of file, relative to root of source tree. package: name of the package serialized_pb: (str) Byte string of serialized descriptor_pb2.FileDescriptorProto. + dependencies: List of other FileDescriptors this FileDescriptor depends on. + message_types_by_name: Dict of message names of their descriptors. + enum_types_by_name: Dict of enum names and their descriptors. + extensions_by_name: Dict of extension names and their descriptors. """ - def __init__(self, name, package, options=None, serialized_pb=None): + def __init__(self, name, package, options=None, serialized_pb=None, + dependencies=None): """Constructor.""" super(FileDescriptor, self).__init__(options, 'FileOptions') @@ -659,10 +721,17 @@ class FileDescriptor(DescriptorBase): self.name = name self.package = package self.serialized_pb = serialized_pb + + self.enum_types_by_name = {} + self.extensions_by_name = {} + self.dependencies = (dependencies or []) + if (api_implementation.Type() == 'cpp' and self.serialized_pb is not None): if api_implementation.Version() == 2: - _message.BuildFile(self.serialized_pb) + # pylint: disable=protected-access + _message.Message._BuildFile(self.serialized_pb) + # pylint: enable=protected-access else: cpp_message.BuildFile(self.serialized_pb) @@ -685,29 +754,96 @@ def _ParseOptions(message, string): return message -def MakeDescriptor(desc_proto, package=''): +def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True): """Make a protobuf Descriptor given a DescriptorProto protobuf. + Handles nested descriptors. Note that this is limited to the scope of defining + a message inside of another message. Composite fields can currently only be + resolved if the message is defined in the same scope as the field. + Args: desc_proto: The descriptor_pb2.DescriptorProto protobuf message. package: Optional package name for the new message Descriptor (string). - + build_file_if_cpp: Update the C++ descriptor pool if api matches. + Set to False on recursion, so no duplicates are created. Returns: A Descriptor for protobuf messages. """ + if api_implementation.Type() == 'cpp' and build_file_if_cpp: + # The C++ implementation requires all descriptors to be backed by the same + # definition in the C++ descriptor pool. To do this, we build a + # FileDescriptorProto with the same definition as this descriptor and build + # it into the pool. + from google.protobuf import descriptor_pb2 + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.message_type.add().MergeFrom(desc_proto) + + # Generate a random name for this proto file to prevent conflicts with + # any imported ones. We need to specify a file name so BuildFile accepts + # our FileDescriptorProto, but it is not important what that file name + # is actually set to. + proto_name = str(uuid.uuid4()) + + if package: + file_descriptor_proto.name = os.path.join(package.replace('.', '/'), + proto_name + '.proto') + file_descriptor_proto.package = package + else: + file_descriptor_proto.name = proto_name + '.proto' + + if api_implementation.Version() == 2: + # pylint: disable=protected-access + _message.Message._BuildFile(file_descriptor_proto.SerializeToString()) + # pylint: enable=protected-access + else: + cpp_message.BuildFile(file_descriptor_proto.SerializeToString()) + full_message_name = [desc_proto.name] if package: full_message_name.insert(0, package) + + # Create Descriptors for enum types + enum_types = {} + for enum_proto in desc_proto.enum_type: + full_name = '.'.join(full_message_name + [enum_proto.name]) + enum_desc = EnumDescriptor( + enum_proto.name, full_name, None, [ + EnumValueDescriptor(enum_val.name, ii, enum_val.number) + for ii, enum_val in enumerate(enum_proto.value)]) + enum_types[full_name] = enum_desc + + # Create Descriptors for nested types + nested_types = {} + for nested_proto in desc_proto.nested_type: + full_name = '.'.join(full_message_name + [nested_proto.name]) + # Nested types are just those defined inside of the message, not all types + # used by fields in the message, so no loops are possible here. + nested_desc = MakeDescriptor(nested_proto, + package='.'.join(full_message_name), + build_file_if_cpp=False) + nested_types[full_name] = nested_desc + fields = [] for field_proto in desc_proto.field: full_name = '.'.join(full_message_name + [field_proto.name]) + enum_desc = None + nested_desc = None + if field_proto.HasField('type_name'): + type_name = field_proto.type_name + full_type_name = '.'.join(full_message_name + + [type_name[type_name.rfind('.')+1:]]) + if full_type_name in nested_types: + nested_desc = nested_types[full_type_name] + elif full_type_name in enum_types: + enum_desc = enum_types[full_type_name] + # Else type_name references a non-local type, which isn't implemented field = FieldDescriptor( field_proto.name, full_name, field_proto.number - 1, field_proto.number, field_proto.type, FieldDescriptor.ProtoTypeToCppProtoType(field_proto.type), - field_proto.label, None, None, None, None, False, None, + field_proto.label, None, nested_desc, enum_desc, None, False, None, has_default_value=False) fields.append(field) desc_name = '.'.join(full_message_name) return Descriptor(desc_proto.name, desc_name, None, None, fields, - [], [], []) + nested_types.values(), enum_types.values(), []) diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py index 8665d3c5..9f5a117c 100644 --- a/python/google/protobuf/descriptor_database.py +++ b/python/google/protobuf/descriptor_database.py @@ -33,6 +33,14 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' +class Error(Exception): + pass + + +class DescriptorDatabaseConflictingDefinitionError(Error): + """Raised when a proto is added with the same name & different descriptor.""" + + class DescriptorDatabase(object): """A container accepting FileDescriptorProtos and maps DescriptorProtos.""" @@ -45,9 +53,18 @@ class DescriptorDatabase(object): Args: file_desc_proto: The FileDescriptorProto to add. + Raises: + DescriptorDatabaseException: if an attempt is made to add a proto + with the same name but different definition than an exisiting + proto in the database. """ + proto_name = file_desc_proto.name + if proto_name not in self._file_desc_protos_by_file: + self._file_desc_protos_by_file[proto_name] = file_desc_proto + elif self._file_desc_protos_by_file[proto_name] != file_desc_proto: + raise DescriptorDatabaseConflictingDefinitionError( + '%s already added, but with different descriptor.' % proto_name) - self._file_desc_protos_by_file[file_desc_proto.name] = file_desc_proto package = file_desc_proto.package for message in file_desc_proto.message_type: self._file_desc_protos_by_symbol.update( diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 8f1f4457..372f458f 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -49,13 +49,34 @@ Below is a straightforward example on how to use this class: The message descriptor can be used in conjunction with the message_factory module in order to create a protocol buffer class that can be encoded and decoded. + +If you want to get a Python class for the specified proto, use the +helper functions inside google.protobuf.message_factory +directly instead of this class. """ __author__ = 'matthewtoia@google.com (Matt Toia)' -from google.protobuf import descriptor_pb2 +import sys + from google.protobuf import descriptor from google.protobuf import descriptor_database +from google.protobuf import text_encoding + + +def _NormalizeFullyQualifiedName(name): + """Remove leading period from fully-qualified type name. + + Due to b/13860351 in descriptor_database.py, types in the root namespace are + generated with a leading period. This function removes that prefix. + + Args: + name: A str, the fully-qualified symbol name. + + Returns: + A str, the normalized fully-qualified symbol name. + """ + return name.lstrip('.') class DescriptorPool(object): @@ -89,6 +110,51 @@ class DescriptorPool(object): self._internal_db.Add(file_desc_proto) + def AddDescriptor(self, desc): + """Adds a Descriptor to the pool, non-recursively. + + If the Descriptor contains nested messages or enums, the caller must + explicitly register them. This method also registers the FileDescriptor + associated with the message. + + Args: + desc: A Descriptor. + """ + if not isinstance(desc, descriptor.Descriptor): + raise TypeError('Expected instance of descriptor.Descriptor.') + + self._descriptors[desc.full_name] = desc + self.AddFileDescriptor(desc.file) + + def AddEnumDescriptor(self, enum_desc): + """Adds an EnumDescriptor to the pool. + + This method also registers the FileDescriptor associated with the message. + + Args: + enum_desc: An EnumDescriptor. + """ + + if not isinstance(enum_desc, descriptor.EnumDescriptor): + raise TypeError('Expected instance of descriptor.EnumDescriptor.') + + self._enum_descriptors[enum_desc.full_name] = enum_desc + self.AddFileDescriptor(enum_desc.file) + + def AddFileDescriptor(self, file_desc): + """Adds a FileDescriptor to the pool, non-recursively. + + If the FileDescriptor contains messages or enums, the caller must explicitly + register them. + + Args: + file_desc: A FileDescriptor. + """ + + if not isinstance(file_desc, descriptor.FileDescriptor): + raise TypeError('Expected instance of descriptor.FileDescriptor.') + self._file_descriptors[file_desc.name] = file_desc + def FindFileByName(self, file_name): """Gets a FileDescriptor by file name. @@ -102,9 +168,15 @@ class DescriptorPool(object): KeyError: if the file can not be found in the pool. """ + try: + return self._file_descriptors[file_name] + except KeyError: + pass + try: file_proto = self._internal_db.FindFileByName(file_name) - except KeyError as error: + except KeyError: + _, error, _ = sys.exc_info() #PY25 compatible for GAE. if self._descriptor_db: file_proto = self._descriptor_db.FindFileByName(file_name) else: @@ -126,9 +198,21 @@ class DescriptorPool(object): KeyError: if the file can not be found in the pool. """ + symbol = _NormalizeFullyQualifiedName(symbol) + try: + return self._descriptors[symbol].file + except KeyError: + pass + + try: + return self._enum_descriptors[symbol].file + except KeyError: + pass + try: file_proto = self._internal_db.FindFileContainingSymbol(symbol) - except KeyError as error: + except KeyError: + _, error, _ = sys.exc_info() #PY25 compatible for GAE. if self._descriptor_db: file_proto = self._descriptor_db.FindFileContainingSymbol(symbol) else: @@ -147,7 +231,7 @@ class DescriptorPool(object): The descriptor for the named type. """ - full_name = full_name.lstrip('.') # fix inconsistent qualified name formats + full_name = _NormalizeFullyQualifiedName(full_name) if full_name not in self._descriptors: self.FindFileContainingSymbol(full_name) return self._descriptors[full_name] @@ -162,7 +246,7 @@ class DescriptorPool(object): The enum descriptor for the named type. """ - full_name = full_name.lstrip('.') # fix inconsistent qualified name formats + full_name = _NormalizeFullyQualifiedName(full_name) if full_name not in self._enum_descriptors: self.FindFileContainingSymbol(full_name) return self._enum_descriptors[full_name] @@ -181,46 +265,56 @@ class DescriptorPool(object): """ if file_proto.name not in self._file_descriptors: + built_deps = list(self._GetDeps(file_proto.dependency)) + direct_deps = [self.FindFileByName(n) for n in file_proto.dependency] + file_descriptor = descriptor.FileDescriptor( name=file_proto.name, package=file_proto.package, options=file_proto.options, - serialized_pb=file_proto.SerializeToString()) + serialized_pb=file_proto.SerializeToString(), + dependencies=direct_deps) scope = {} - dependencies = list(self._GetDeps(file_proto)) - for dependency in dependencies: - dep_desc = self.FindFileByName(dependency.name) - dep_proto = descriptor_pb2.FileDescriptorProto.FromString( - dep_desc.serialized_pb) - package = '.' + dep_proto.package - package_prefix = package + '.' - - def _strip_package(symbol): - if symbol.startswith(package_prefix): - return symbol[len(package_prefix):] - return symbol - - symbols = list(self._ExtractSymbols(dep_proto.message_type, package)) - scope.update(symbols) - scope.update((_strip_package(k), v) for k, v in symbols) - - symbols = list(self._ExtractEnums(dep_proto.enum_type, package)) - scope.update(symbols) - scope.update((_strip_package(k), v) for k, v in symbols) + # This loop extracts all the message and enum types from all the + # dependencoes of the file_proto. This is necessary to create the + # scope of available message types when defining the passed in + # file proto. + for dependency in built_deps: + scope.update(self._ExtractSymbols( + dependency.message_types_by_name.values())) + scope.update((_PrefixWithDot(enum.full_name), enum) + for enum in dependency.enum_types_by_name.values()) for message_type in file_proto.message_type: message_desc = self._ConvertMessageDescriptor( message_type, file_proto.package, file_descriptor, scope) file_descriptor.message_types_by_name[message_desc.name] = message_desc + for enum_type in file_proto.enum_type: - self._ConvertEnumDescriptor(enum_type, file_proto.package, - file_descriptor, None, scope) - for desc_proto in self._ExtractMessages(file_proto.message_type): - self._SetFieldTypes(desc_proto, scope) + file_descriptor.enum_types_by_name[enum_type.name] = ( + self._ConvertEnumDescriptor(enum_type, file_proto.package, + file_descriptor, None, scope)) + + for index, extension_proto in enumerate(file_proto.extension): + extension_desc = self.MakeFieldDescriptor( + extension_proto, file_proto.package, index, is_extension=True) + extension_desc.containing_type = self._GetTypeFromScope( + file_descriptor.package, extension_proto.extendee, scope) + self.SetFieldType(extension_proto, extension_desc, + file_descriptor.package, scope) + file_descriptor.extensions_by_name[extension_desc.name] = extension_desc + + for desc_proto in file_proto.message_type: + self.SetAllFieldTypes(file_proto.package, desc_proto, scope) + + if file_proto.package: + desc_proto_prefix = _PrefixWithDot(file_proto.package) + else: + desc_proto_prefix = '' for desc_proto in file_proto.message_type: - desc = scope[desc_proto.name] + desc = self._GetTypeFromScope(desc_proto_prefix, desc_proto.name, scope) file_descriptor.message_types_by_name[desc_proto.name] = desc self.Add(file_proto) self._file_descriptors[file_proto.name] = file_descriptor @@ -260,10 +354,15 @@ class DescriptorPool(object): enums = [ self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope) for enum in desc_proto.enum_type] - fields = [self._MakeFieldDescriptor(field, desc_name, index) + fields = [self.MakeFieldDescriptor(field, desc_name, index) for index, field in enumerate(desc_proto.field)] - extensions = [self._MakeFieldDescriptor(extension, desc_name, True) - for index, extension in enumerate(desc_proto.extension)] + extensions = [ + self.MakeFieldDescriptor(extension, desc_name, index, is_extension=True) + for index, extension in enumerate(desc_proto.extension)] + oneofs = [ + descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)), + index, None, []) + for index, desc in enumerate(desc_proto.oneof_decl)] extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range] if extension_ranges: is_extendable = True @@ -275,6 +374,7 @@ class DescriptorPool(object): filename=file_name, containing_type=None, fields=fields, + oneofs=oneofs, nested_types=nested, enum_types=enums, extensions=extensions, @@ -288,8 +388,13 @@ class DescriptorPool(object): nested.containing_type = desc for enum in desc.enum_types: enum.containing_type = desc - scope[desc_proto.name] = desc - scope['.' + desc_name] = desc + for field_index, field_desc in enumerate(desc_proto.field): + if field_desc.HasField('oneof_index'): + oneof_index = field_desc.oneof_index + oneofs[oneof_index].fields.append(fields[field_index]) + fields[field_index].containing_oneof = oneofs[oneof_index] + + scope[_PrefixWithDot(desc_name)] = desc self._descriptors[desc_name] = desc return desc @@ -327,13 +432,12 @@ class DescriptorPool(object): values=values, containing_type=containing_type, options=enum_proto.options) - scope[enum_proto.name] = desc scope['.%s' % enum_name] = desc self._enum_descriptors[enum_name] = desc return desc - def _MakeFieldDescriptor(self, field_proto, message_name, index, - is_extension=False): + def MakeFieldDescriptor(self, field_proto, message_name, index, + is_extension=False): """Creates a field descriptor from a FieldDescriptorProto. For message and enum type fields, this method will do a look up @@ -374,65 +478,93 @@ class DescriptorPool(object): extension_scope=None, options=field_proto.options) - def _SetFieldTypes(self, desc_proto, scope): - """Sets the field's type, cpp_type, message_type and enum_type. + def SetAllFieldTypes(self, package, desc_proto, scope): + """Sets all the descriptor's fields's types. + + This method also sets the containing types on any extensions. Args: + package: The current package of desc_proto. desc_proto: The message descriptor to update. scope: Enclosing scope of available types. """ - desc = scope[desc_proto.name] - for field_proto, field_desc in zip(desc_proto.field, desc.fields): - if field_proto.type_name: - type_name = field_proto.type_name - if type_name not in scope: - type_name = '.' + type_name - desc = scope[type_name] - else: - desc = None + package = _PrefixWithDot(package) - if not field_proto.HasField('type'): - if isinstance(desc, descriptor.Descriptor): - field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE - else: - field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM - - field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType( - field_proto.type) - - if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE - or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP): - field_desc.message_type = desc - - if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: - field_desc.enum_type = desc - - if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED: - field_desc.has_default = False - field_desc.default_value = [] - elif field_proto.HasField('default_value'): - field_desc.has_default = True - if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or - field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): - field_desc.default_value = float(field_proto.default_value) - elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: - field_desc.default_value = field_proto.default_value - elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: - field_desc.default_value = field_proto.default_value.lower() == 'true' - elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: - field_desc.default_value = field_desc.enum_type.values_by_name[ - field_proto.default_value].index - else: - field_desc.default_value = int(field_proto.default_value) - else: - field_desc.has_default = False - field_desc.default_value = None + main_desc = self._GetTypeFromScope(package, desc_proto.name, scope) - field_desc.type = field_proto.type + if package == '.': + nested_package = _PrefixWithDot(desc_proto.name) + else: + nested_package = '.'.join([package, desc_proto.name]) + + for field_proto, field_desc in zip(desc_proto.field, main_desc.fields): + self.SetFieldType(field_proto, field_desc, nested_package, scope) + + for extension_proto, extension_desc in ( + zip(desc_proto.extension, main_desc.extensions)): + extension_desc.containing_type = self._GetTypeFromScope( + nested_package, extension_proto.extendee, scope) + self.SetFieldType(extension_proto, extension_desc, nested_package, scope) for nested_type in desc_proto.nested_type: - self._SetFieldTypes(nested_type, scope) + self.SetAllFieldTypes(nested_package, nested_type, scope) + + def SetFieldType(self, field_proto, field_desc, package, scope): + """Sets the field's type, cpp_type, message_type and enum_type. + + Args: + field_proto: Data about the field in proto format. + field_desc: The descriptor to modiy. + package: The package the field's container is in. + scope: Enclosing scope of available types. + """ + if field_proto.type_name: + desc = self._GetTypeFromScope(package, field_proto.type_name, scope) + else: + desc = None + + if not field_proto.HasField('type'): + if isinstance(desc, descriptor.Descriptor): + field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE + else: + field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM + + field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType( + field_proto.type) + + if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE + or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP): + field_desc.message_type = desc + + if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: + field_desc.enum_type = desc + + if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED: + field_desc.has_default_value = False + field_desc.default_value = [] + elif field_proto.HasField('default_value'): + field_desc.has_default_value = True + if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or + field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): + field_desc.default_value = float(field_proto.default_value) + elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: + field_desc.default_value = field_proto.default_value + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: + field_desc.default_value = field_proto.default_value.lower() == 'true' + elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: + field_desc.default_value = field_desc.enum_type.values_by_name[ + field_proto.default_value].index + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: + field_desc.default_value = text_encoding.CUnescape( + field_proto.default_value) + else: + field_desc.default_value = int(field_proto.default_value) + else: + field_desc.has_default_value = False + field_desc.default_value = None + + field_desc.type = field_proto.type def _MakeEnumValueDescriptor(self, value_proto, index): """Creates a enum value descriptor object from a enum value proto. @@ -452,76 +584,60 @@ class DescriptorPool(object): options=value_proto.options, type=None) - def _ExtractSymbols(self, desc_protos, package): + def _ExtractSymbols(self, descriptors): """Pulls out all the symbols from descriptor protos. Args: - desc_protos: The protos to extract symbols from. - package: The package containing the descriptor type. + descriptors: The messages to extract descriptors from. Yields: A two element tuple of the type name and descriptor object. """ - for desc_proto in desc_protos: - if package: - message_name = '.'.join((package, desc_proto.name)) - else: - message_name = desc_proto.name - message_desc = self.FindMessageTypeByName(message_name) - yield (message_name, message_desc) - for symbol in self._ExtractSymbols(desc_proto.nested_type, message_name): - yield symbol - for symbol in self._ExtractEnums(desc_proto.enum_type, message_name): + for desc in descriptors: + yield (_PrefixWithDot(desc.full_name), desc) + for symbol in self._ExtractSymbols(desc.nested_types): yield symbol + for enum in desc.enum_types: + yield (_PrefixWithDot(enum.full_name), enum) - def _ExtractEnums(self, enum_protos, package): - """Pulls out all the symbols from enum protos. + def _GetDeps(self, dependencies): + """Recursively finds dependencies for file protos. Args: - enum_protos: The protos to extract symbols from. - package: The package containing the enum type. + dependencies: The names of the files being depended on. Yields: - A two element tuple of the type name and enum descriptor object. + Each direct and indirect dependency. """ - for enum_proto in enum_protos: - if package: - enum_name = '.'.join((package, enum_proto.name)) - else: - enum_name = enum_proto.name - enum_desc = self.FindEnumTypeByName(enum_name) - yield (enum_name, enum_desc) + for dependency in dependencies: + dep_desc = self.FindFileByName(dependency) + yield dep_desc + for parent_dep in dep_desc.dependencies: + yield parent_dep - def _ExtractMessages(self, desc_protos): - """Pulls out all the message protos from descriptos. + def _GetTypeFromScope(self, package, type_name, scope): + """Finds a given type name in the current scope. Args: - desc_protos: The protos to extract symbols from. + package: The package the proto should be located in. + type_name: The name of the type to be found in the scope. + scope: Dict mapping short and full symbols to message and enum types. - Yields: - Descriptor protos. + Returns: + The descriptor for the requested type. """ + if type_name not in scope: + components = _PrefixWithDot(package).split('.') + while components: + possible_match = '.'.join(components + [type_name]) + if possible_match in scope: + type_name = possible_match + break + else: + components.pop(-1) + return scope[type_name] - for desc_proto in desc_protos: - yield desc_proto - for message in self._ExtractMessages(desc_proto.nested_type): - yield message - - def _GetDeps(self, file_proto): - """Recursively finds dependencies for file protos. - - Args: - file_proto: The proto to get dependencies from. - - Yields: - Each direct and indirect dependency. - """ - for dependency in file_proto.dependency: - dep_desc = self.FindFileByName(dependency) - dep_proto = descriptor_pb2.FileDescriptorProto.FromString( - dep_desc.serialized_pb) - yield dep_proto - for parent_dep in self._GetDeps(dep_proto): - yield parent_dep +def _PrefixWithDot(name): + return name if name.startswith('.') else '.%s' % name diff --git a/python/google/protobuf/internal/api_implementation.cc b/python/google/protobuf/internal/api_implementation.cc new file mode 100644 index 00000000..ad6fd9c8 --- /dev/null +++ b/python/google/protobuf/internal/api_implementation.cc @@ -0,0 +1,139 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include + +namespace google { +namespace protobuf { +namespace python { + +// Version constant. +// This is either 0 for python, 1 for CPP V1, 2 for CPP V2. +// +// 0 is default and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python +// +// 1 is set with -DPYTHON_PROTO2_CPP_IMPL_V1 and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp +// and +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=1 +// +// 2 is set with -DPYTHON_PROTO2_CPP_IMPL_V2 and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp +// and +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2 +#ifdef PYTHON_PROTO2_CPP_IMPL_V1 +#if PY_MAJOR_VERSION >= 3 +#error "PYTHON_PROTO2_CPP_IMPL_V1 is not supported under Python 3." +#endif +static int kImplVersion = 1; +#else +#ifdef PYTHON_PROTO2_CPP_IMPL_V2 +static int kImplVersion = 2; +#else +#ifdef PYTHON_PROTO2_PYTHON_IMPL +static int kImplVersion = 0; +#else + +// The defaults are set here. Python 3 uses the fast C++ APIv2 by default. +// Python 2 still uses the Python version by default until some compatibility +// issues can be worked around. +#if PY_MAJOR_VERSION >= 3 +static int kImplVersion = 2; +#else +static int kImplVersion = 0; +#endif + +#endif // PYTHON_PROTO2_PYTHON_IMPL +#endif // PYTHON_PROTO2_CPP_IMPL_V2 +#endif // PYTHON_PROTO2_CPP_IMPL_V1 + +static const char* kImplVersionName = "api_version"; + +static const char* kModuleName = "_api_implementation"; +static const char kModuleDocstring[] = +"_api_implementation is a module that exposes compile-time constants that\n" +"determine the default API implementation to use for Python proto2.\n" +"\n" +"It complements api_implementation.py by setting defaults using compile-time\n" +"constants defined in C, such that one can set defaults at compilation\n" +"(e.g. with blaze flag --copt=-DPYTHON_PROTO2_CPP_IMPL_V2)."; + +#if PY_MAJOR_VERSION >= 3 +static struct PyModuleDef _module = { + PyModuleDef_HEAD_INIT, + kModuleName, + kModuleDocstring, + -1, + NULL, + NULL, + NULL, + NULL, + NULL +}; +#define INITFUNC PyInit__api_implementation +#define INITFUNC_ERRORVAL NULL +#else +#define INITFUNC init_api_implementation +#define INITFUNC_ERRORVAL +#endif + +extern "C" { + PyMODINIT_FUNC INITFUNC() { +#if PY_MAJOR_VERSION >= 3 + PyObject *module = PyModule_Create(&_module); +#else + PyObject *module = Py_InitModule3( + const_cast(kModuleName), + NULL, + const_cast(kModuleDocstring)); +#endif + if (module == NULL) { + return INITFUNC_ERRORVAL; + } + + // Adds the module variable "api_version". + if (PyModule_AddIntConstant( + module, + const_cast(kImplVersionName), + kImplVersion)) +#if PY_MAJOR_VERSION < 3 + return; +#else + { Py_DECREF(module); return NULL; } + + return module; +#endif + } +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py index ce02a329..cbb85747 100755 --- a/python/google/protobuf/internal/api_implementation.py +++ b/python/google/protobuf/internal/api_implementation.py @@ -28,41 +28,44 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Determine which implementation of the protobuf API is used in this process. """ -This module is the central entity that determines which implementation of the -API is used. -""" - -__author__ = 'petar@google.com (Petar Petrov)' import os +import sys + +try: + # pylint: disable=g-import-not-at-top + from google.protobuf.internal import _api_implementation + # The compile-time constants in the _api_implementation module can be used to + # switch to a certain implementation of the Python API at build time. + _api_version = _api_implementation.api_version + del _api_implementation +except ImportError: + _api_version = 0 + +_default_implementation_type = ( + 'python' if _api_version == 0 else 'cpp') +_default_version_str = ( + '1' if _api_version <= 1 else '2') + # This environment variable can be used to switch to a certain implementation -# of the Python API. Right now only 'python' and 'cpp' are valid values. Any -# other value will be ignored. +# of the Python API, overriding the compile-time constants in the +# _api_implementation module. Right now only 'python' and 'cpp' are valid +# values. Any other value will be ignored. _implementation_type = os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', - 'python') - + _default_implementation_type) if _implementation_type != 'python': - # For now, by default use the pure-Python implementation. - # The code below checks if the C extension is available and - # uses it if it is available. _implementation_type = 'cpp' - ## Determine automatically which implementation to use. - #try: - # from google.protobuf.internal import cpp_message - # _implementation_type = 'cpp' - #except ImportError, e: - # _implementation_type = 'python' - # This environment variable can be used to switch between the two -# 'cpp' implementations. Right now only 1 and 2 are valid values. Any -# other value will be ignored. +# 'cpp' implementations, overriding the compile-time constants in the +# _api_implementation module. Right now only 1 and 2 are valid values. Any other +# value will be ignored. _implementation_version_str = os.getenv( 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', - '1') - + _default_version_str) if _implementation_version_str not in ('1', '2'): raise ValueError( @@ -70,11 +73,9 @@ if _implementation_version_str not in ('1', '2'): _implementation_version_str + "' (supported versions: 1, 2)" ) - _implementation_version = int(_implementation_version_str) - # Usage of this function is discouraged. Clients shouldn't care which # implementation of the API is in use. Note that there is no guarantee # that differences between APIs will be maintained. @@ -82,6 +83,7 @@ _implementation_version = int(_implementation_version_str) def Type(): return _implementation_type + # See comment on 'Type' above. def Version(): return _implementation_version diff --git a/python/google/protobuf/internal/api_implementation_default_test.py b/python/google/protobuf/internal/api_implementation_default_test.py new file mode 100644 index 00000000..b2b41284 --- /dev/null +++ b/python/google/protobuf/internal/api_implementation_default_test.py @@ -0,0 +1,63 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test that the api_implementation defaults are what we expect.""" + +import os +import sys +# Clear environment implementation settings before the google3 imports. +os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', None) +os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', None) + +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation + + +class ApiImplementationDefaultTest(basetest.TestCase): + + if sys.version_info.major <= 2: + + def testThatPythonIsTheDefault(self): + """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail.""" + self.assertEqual('python', api_implementation.Type()) + + else: + + def testThatCppApiV2IsTheDefault(self): + """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail.""" + self.assertEqual('cpp', api_implementation.Type()) + self.assertEqual(2, api_implementation.Version()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 34b35f8a..5797e81b 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -108,15 +108,13 @@ class RepeatedScalarFieldContainer(BaseContainer): def append(self, value): """Appends an item to the list. Similar to list.append().""" - self._type_checker.CheckValue(value) - self._values.append(value) + self._values.append(self._type_checker.CheckValue(value)) if not self._message_listener.dirty: self._message_listener.Modified() def insert(self, key, value): """Inserts the item at the specified position. Similar to list.insert().""" - self._type_checker.CheckValue(value) - self._values.insert(key, value) + self._values.insert(key, self._type_checker.CheckValue(value)) if not self._message_listener.dirty: self._message_listener.Modified() @@ -127,8 +125,7 @@ class RepeatedScalarFieldContainer(BaseContainer): new_values = [] for elem in elem_seq: - self._type_checker.CheckValue(elem) - new_values.append(elem) + new_values.append(self._type_checker.CheckValue(elem)) self._values.extend(new_values) self._message_listener.Modified() @@ -146,9 +143,13 @@ class RepeatedScalarFieldContainer(BaseContainer): def __setitem__(self, key, value): """Sets the item on the specified position.""" - self._type_checker.CheckValue(value) - self._values[key] = value - self._message_listener.Modified() + if isinstance(key, slice): # PY3 + if key.step is not None: + raise ValueError('Extended slices not supported') + self.__setslice__(key.start, key.stop, value) + else: + self._values[key] = self._type_checker.CheckValue(value) + self._message_listener.Modified() def __getslice__(self, start, stop): """Retrieves the subset of items from between the specified indices.""" @@ -158,8 +159,7 @@ class RepeatedScalarFieldContainer(BaseContainer): """Sets the subset of items from between the specified indices.""" new_values = [] for value in values: - self._type_checker.CheckValue(value) - new_values.append(value) + new_values.append(self._type_checker.CheckValue(value)) self._values[start:stop] = new_values self._message_listener.Modified() diff --git a/python/google/protobuf/internal/cpp_message.py b/python/google/protobuf/internal/cpp_message.py index 23ab9ba4..8eb38ca4 100755 --- a/python/google/protobuf/internal/cpp_message.py +++ b/python/google/protobuf/internal/cpp_message.py @@ -610,7 +610,7 @@ def _AddMessageMethods(message_descriptor, cls): return self._cmsg.FindInitializationErrors() def __str__(self): - return self._cmsg.DebugString() + return str(self._cmsg) def __eq__(self, other): if self is other: diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index cb6f5729..651ee0d4 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#PY25 compatible for GAE. +# +# Copyright 2009 Google Inc. All Rights Reserved. + """Code for decoding protocol buffer primitives. This code is very similar to encoder.py -- read the docs for that module first. @@ -81,6 +85,8 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it. __author__ = 'kenton@google.com (Kenton Varda)' import struct +import sys ##PY25 +_PY2 = sys.version_info[0] < 3 ##PY25 from google.protobuf.internal import encoder from google.protobuf.internal import wire_format from google.protobuf import message @@ -98,7 +104,7 @@ _NAN = _POS_INF * 0 _DecodeError = message.DecodeError -def _VarintDecoder(mask): +def _VarintDecoder(mask, result_type): """Return an encoder for a basic varint value (does not include tag). Decoded values will be bitwise-anded with the given mask before being @@ -109,15 +115,18 @@ def _VarintDecoder(mask): """ local_ord = ord + py2 = _PY2 ##PY25 +##!PY25 py2 = str is bytes def DecodeVarint(buffer, pos): result = 0 shift = 0 while 1: - b = local_ord(buffer[pos]) + b = local_ord(buffer[pos]) if py2 else buffer[pos] result |= ((b & 0x7f) << shift) pos += 1 if not (b & 0x80): result &= mask + result = result_type(result) return (result, pos) shift += 7 if shift >= 64: @@ -125,15 +134,17 @@ def _VarintDecoder(mask): return DecodeVarint -def _SignedVarintDecoder(mask): +def _SignedVarintDecoder(mask, result_type): """Like _VarintDecoder() but decodes signed values.""" local_ord = ord + py2 = _PY2 ##PY25 +##!PY25 py2 = str is bytes def DecodeVarint(buffer, pos): result = 0 shift = 0 while 1: - b = local_ord(buffer[pos]) + b = local_ord(buffer[pos]) if py2 else buffer[pos] result |= ((b & 0x7f) << shift) pos += 1 if not (b & 0x80): @@ -142,19 +153,23 @@ def _SignedVarintDecoder(mask): result |= ~mask else: result &= mask + result = result_type(result) return (result, pos) shift += 7 if shift >= 64: raise _DecodeError('Too many bytes when decoding varint.') return DecodeVarint +# We force 32-bit values to int and 64-bit values to long to make +# alternate implementations where the distinction is more significant +# (e.g. the C++ implementation) simpler. -_DecodeVarint = _VarintDecoder((1 << 64) - 1) -_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1) +_DecodeVarint = _VarintDecoder((1 << 64) - 1, long) +_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long) # Use these versions for values which must be limited to 32 bits. -_DecodeVarint32 = _VarintDecoder((1 << 32) - 1) -_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1) +_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int) +_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int) def ReadTag(buffer, pos): @@ -168,8 +183,10 @@ def ReadTag(buffer, pos): use that, but not in Python. """ + py2 = _PY2 ##PY25 +##!PY25 py2 = str is bytes start = pos - while ord(buffer[pos]) & 0x80: + while (ord(buffer[pos]) if py2 else buffer[pos]) & 0x80: pos += 1 pos += 1 return (buffer[start:pos], pos) @@ -284,6 +301,7 @@ def _FloatDecoder(): """ local_unpack = struct.unpack + b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25 def InnerDecode(buffer, pos): # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign @@ -294,13 +312,17 @@ def _FloatDecoder(): # If this value has all its exponent bits set, then it's non-finite. # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. # To avoid that, we parse it specially. - if ((float_bytes[3] in '\x7F\xFF') - and (float_bytes[2] >= '\x80')): + if ((float_bytes[3:4] in b('\x7F\xFF')) ##PY25 +##!PY25 if ((float_bytes[3:4] in b'\x7F\xFF') + and (float_bytes[2:3] >= b('\x80'))): ##PY25 +##!PY25 and (float_bytes[2:3] >= b'\x80')): # If at least one significand bit is set... - if float_bytes[0:3] != '\x00\x00\x80': + if float_bytes[0:3] != b('\x00\x00\x80'): ##PY25 +##!PY25 if float_bytes[0:3] != b'\x00\x00\x80': return (_NAN, new_pos) # If sign bit is set... - if float_bytes[3] == '\xFF': + if float_bytes[3:4] == b('\xFF'): ##PY25 +##!PY25 if float_bytes[3:4] == b'\xFF': return (_NEG_INF, new_pos) return (_POS_INF, new_pos) @@ -319,6 +341,7 @@ def _DoubleDecoder(): """ local_unpack = struct.unpack + b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25 def InnerDecode(buffer, pos): # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign @@ -329,9 +352,12 @@ def _DoubleDecoder(): # If this value has all its exponent bits set and at least one significand # bit set, it's not a number. In Python 2.4, struct.unpack will treat it # as inf or -inf. To avoid that, we treat it specially. - if ((double_bytes[7] in '\x7F\xFF') - and (double_bytes[6] >= '\xF0') - and (double_bytes[0:7] != '\x00\x00\x00\x00\x00\x00\xF0')): +##!PY25 if ((double_bytes[7:8] in b'\x7F\xFF') +##!PY25 and (double_bytes[6:7] >= b'\xF0') +##!PY25 and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): + if ((double_bytes[7:8] in b('\x7F\xFF')) ##PY25 + and (double_bytes[6:7] >= b('\xF0')) ##PY25 + and (double_bytes[0:7] != b('\x00\x00\x00\x00\x00\x00\xF0'))): ##PY25 return (_NAN, new_pos) # Note that we expect someone up-stack to catch struct.error and convert @@ -342,10 +368,86 @@ def _DoubleDecoder(): return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode) +def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): + enum_type = key.enum_type + if is_packed: + local_DecodeVarint = _DecodeVarint + def DecodePackedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + (endpoint, pos) = local_DecodeVarint(buffer, pos) + endpoint += pos + if endpoint > end: + raise _DecodeError('Truncated message.') + while pos < endpoint: + value_start_pos = pos + (element, pos) = _DecodeSignedVarint32(buffer, pos) + if element in enum_type.values_by_number: + value.append(element) + else: + if not message._unknown_fields: + message._unknown_fields = [] + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_VARINT) + message._unknown_fields.append( + (tag_bytes, buffer[value_start_pos:pos])) + if pos > endpoint: + if element in enum_type.values_by_number: + del value[-1] # Discard corrupt value. + else: + del message._unknown_fields[-1] + raise _DecodeError('Packed element was truncated.') + return pos + return DecodePackedField + elif is_repeated: + tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (element, new_pos) = _DecodeSignedVarint32(buffer, pos) + if element in enum_type.values_by_number: + value.append(element) + else: + if not message._unknown_fields: + message._unknown_fields = [] + message._unknown_fields.append( + (tag_bytes, buffer[pos:new_pos])) + # Predict that the next tag is another copy of the same repeated + # field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos >= end: + # Prediction failed. Return. + if new_pos > end: + raise _DecodeError('Truncated message.') + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + value_start_pos = pos + (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) + if pos > end: + raise _DecodeError('Truncated message.') + if enum_value in enum_type.values_by_number: + field_dict[key] = enum_value + else: + if not message._unknown_fields: + message._unknown_fields = [] + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_VARINT) + message._unknown_fields.append( + (tag_bytes, buffer[value_start_pos:pos])) + return pos + return DecodeField + + # -------------------------------------------------------------------- -Int32Decoder = EnumDecoder = _SimpleDecoder( +Int32Decoder = _SimpleDecoder( wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) Int64Decoder = _SimpleDecoder( @@ -380,6 +482,14 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): local_DecodeVarint = _DecodeVarint local_unicode = unicode + def _ConvertToUnicode(byte_str): + try: + return local_unicode(byte_str, 'utf-8') + except UnicodeDecodeError, e: + # add more information to the error message and re-raise it. + e.reason = '%s in field: %s' % (e, key.full_name) + raise + assert not is_packed if is_repeated: tag_bytes = encoder.TagBytes(field_number, @@ -394,7 +504,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - value.append(local_unicode(buffer[pos:new_pos], 'utf-8')) + value.append(_ConvertToUnicode(buffer[pos:new_pos])) # Predict that the next tag is another copy of the same repeated field. pos = new_pos + tag_len if buffer[new_pos:pos] != tag_bytes or new_pos == end: @@ -407,7 +517,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8') + field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) return new_pos return DecodeField @@ -631,8 +741,10 @@ def MessageSetItemDecoder(extensions_by_number): def _SkipVarint(buffer, pos, end): """Skip a varint value. Returns the new position.""" - - while ord(buffer[pos]) & 0x80: + # Previously ord(buffer[pos]) raised IndexError when pos is out of range. + # With this code, ord(b'') raises TypeError. Both are handled in + # python_message.py to generate a 'Truncated message' error. + while ord(buffer[pos:pos+1]) & 0x80: pos += 1 pos += 1 if pos > end: @@ -699,7 +811,6 @@ def _FieldSkipper(): ] wiretype_mask = wire_format.TAG_TYPE_MASK - local_ord = ord def SkipField(buffer, pos, end, tag_bytes): """Skips a field with the specified tag. @@ -712,7 +823,7 @@ def _FieldSkipper(): """ # The wire type is always in the first byte since varints are little-endian. - wire_type = local_ord(tag_bytes[0]) & wiretype_mask + wire_type = ord(tag_bytes[0:1]) & wiretype_mask return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) return SkipField diff --git a/python/google/protobuf/internal/descriptor_cpp2_test.py b/python/google/protobuf/internal/descriptor_cpp2_test.py new file mode 100644 index 00000000..3a3ff298 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_cpp2_test.py @@ -0,0 +1,58 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.pyext behavior.""" + +__author__ = 'anuraag@google.com (Anuraag Agrawal)' + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' + +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import +from google.protobuf.internal.descriptor_test import * + + +class ConfirmCppApi2Test(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('cpp', api_implementation.Type()) + self.assertEqual(2, api_implementation.Version()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py index d0ca7892..856f4723 100644 --- a/python/google/protobuf/internal/descriptor_database_test.py +++ b/python/google/protobuf/internal/descriptor_database_test.py @@ -34,13 +34,13 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' -import unittest +from google.apputils import basetest from google.protobuf import descriptor_pb2 from google.protobuf.internal import factory_test2_pb2 from google.protobuf import descriptor_database -class DescriptorDatabaseTest(unittest.TestCase): +class DescriptorDatabaseTest(basetest.TestCase): def testAdd(self): db = descriptor_database.DescriptorDatabase() @@ -49,15 +49,15 @@ class DescriptorDatabaseTest(unittest.TestCase): db.Add(file_desc_proto) self.assertEquals(file_desc_proto, db.FindFileByName( - 'net/proto2/python/internal/factory_test2.proto')) + 'google/protobuf/internal/factory_test2.proto')) self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( - 'net.proto2.python.internal.Factory2Message')) + 'google.protobuf.python.internal.Factory2Message')) self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( - 'net.proto2.python.internal.Factory2Message.NestedFactory2Message')) + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message')) self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( - 'net.proto2.python.internal.Factory2Enum')) + 'google.protobuf.python.internal.Factory2Enum')) self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( - 'net.proto2.python.internal.Factory2Message.NestedFactory2Enum')) + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum')) if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index a615d787..7c1ce2e1 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -34,8 +34,15 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' +import os import unittest + +from google.apputils import basetest +from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 +from google.protobuf.internal import api_implementation +from google.protobuf.internal import descriptor_pool_test1_pb2 +from google.protobuf.internal import descriptor_pool_test2_pb2 from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 from google.protobuf import descriptor @@ -43,7 +50,7 @@ from google.protobuf import descriptor_database from google.protobuf import descriptor_pool -class DescriptorPoolTest(unittest.TestCase): +class DescriptorPoolTest(basetest.TestCase): def setUp(self): self.pool = descriptor_pool.DescriptorPool() @@ -55,57 +62,51 @@ class DescriptorPoolTest(unittest.TestCase): self.pool.Add(self.factory_test2_fd) def testFindFileByName(self): - name1 = 'net/proto2/python/internal/factory_test1.proto' + name1 = 'google/protobuf/internal/factory_test1.proto' file_desc1 = self.pool.FindFileByName(name1) self.assertIsInstance(file_desc1, descriptor.FileDescriptor) self.assertEquals(name1, file_desc1.name) - self.assertEquals('net.proto2.python.internal', file_desc1.package) + self.assertEquals('google.protobuf.python.internal', file_desc1.package) self.assertIn('Factory1Message', file_desc1.message_types_by_name) - name2 = 'net/proto2/python/internal/factory_test2.proto' + name2 = 'google/protobuf/internal/factory_test2.proto' file_desc2 = self.pool.FindFileByName(name2) self.assertIsInstance(file_desc2, descriptor.FileDescriptor) self.assertEquals(name2, file_desc2.name) - self.assertEquals('net.proto2.python.internal', file_desc2.package) + self.assertEquals('google.protobuf.python.internal', file_desc2.package) self.assertIn('Factory2Message', file_desc2.message_types_by_name) def testFindFileByNameFailure(self): - try: + with self.assertRaises(KeyError): self.pool.FindFileByName('Does not exist') - self.fail('Expected KeyError') - except KeyError: - pass def testFindFileContainingSymbol(self): file_desc1 = self.pool.FindFileContainingSymbol( - 'net.proto2.python.internal.Factory1Message') + 'google.protobuf.python.internal.Factory1Message') self.assertIsInstance(file_desc1, descriptor.FileDescriptor) - self.assertEquals('net/proto2/python/internal/factory_test1.proto', + self.assertEquals('google/protobuf/internal/factory_test1.proto', file_desc1.name) - self.assertEquals('net.proto2.python.internal', file_desc1.package) + self.assertEquals('google.protobuf.python.internal', file_desc1.package) self.assertIn('Factory1Message', file_desc1.message_types_by_name) file_desc2 = self.pool.FindFileContainingSymbol( - 'net.proto2.python.internal.Factory2Message') + 'google.protobuf.python.internal.Factory2Message') self.assertIsInstance(file_desc2, descriptor.FileDescriptor) - self.assertEquals('net/proto2/python/internal/factory_test2.proto', + self.assertEquals('google/protobuf/internal/factory_test2.proto', file_desc2.name) - self.assertEquals('net.proto2.python.internal', file_desc2.package) + self.assertEquals('google.protobuf.python.internal', file_desc2.package) self.assertIn('Factory2Message', file_desc2.message_types_by_name) def testFindFileContainingSymbolFailure(self): - try: + with self.assertRaises(KeyError): self.pool.FindFileContainingSymbol('Does not exist') - self.fail('Expected KeyError') - except KeyError: - pass def testFindMessageTypeByName(self): msg1 = self.pool.FindMessageTypeByName( - 'net.proto2.python.internal.Factory1Message') + 'google.protobuf.python.internal.Factory1Message') self.assertIsInstance(msg1, descriptor.Descriptor) self.assertEquals('Factory1Message', msg1.name) - self.assertEquals('net.proto2.python.internal.Factory1Message', + self.assertEquals('google.protobuf.python.internal.Factory1Message', msg1.full_name) self.assertEquals(None, msg1.containing_type) @@ -123,10 +124,10 @@ class DescriptorPoolTest(unittest.TestCase): 'nested_factory_1_enum'].enum_type) msg2 = self.pool.FindMessageTypeByName( - 'net.proto2.python.internal.Factory2Message') + 'google.protobuf.python.internal.Factory2Message') self.assertIsInstance(msg2, descriptor.Descriptor) self.assertEquals('Factory2Message', msg2.name) - self.assertEquals('net.proto2.python.internal.Factory2Message', + self.assertEquals('google.protobuf.python.internal.Factory2Message', msg2.full_name) self.assertIsNone(msg2.containing_type) @@ -143,45 +144,57 @@ class DescriptorPoolTest(unittest.TestCase): self.assertEquals(nested_enum2, msg2.fields_by_name[ 'nested_factory_2_enum'].enum_type) - self.assertTrue(msg2.fields_by_name['int_with_default'].has_default) + self.assertTrue(msg2.fields_by_name['int_with_default'].has_default_value) self.assertEquals( 1776, msg2.fields_by_name['int_with_default'].default_value) - self.assertTrue(msg2.fields_by_name['double_with_default'].has_default) + self.assertTrue( + msg2.fields_by_name['double_with_default'].has_default_value) self.assertEquals( 9.99, msg2.fields_by_name['double_with_default'].default_value) - self.assertTrue(msg2.fields_by_name['string_with_default'].has_default) + self.assertTrue( + msg2.fields_by_name['string_with_default'].has_default_value) self.assertEquals( 'hello world', msg2.fields_by_name['string_with_default'].default_value) - self.assertTrue(msg2.fields_by_name['bool_with_default'].has_default) + self.assertTrue(msg2.fields_by_name['bool_with_default'].has_default_value) self.assertFalse(msg2.fields_by_name['bool_with_default'].default_value) - self.assertTrue(msg2.fields_by_name['enum_with_default'].has_default) + self.assertTrue(msg2.fields_by_name['enum_with_default'].has_default_value) self.assertEquals( 1, msg2.fields_by_name['enum_with_default'].default_value) msg3 = self.pool.FindMessageTypeByName( - 'net.proto2.python.internal.Factory2Message.NestedFactory2Message') + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message') self.assertEquals(nested_msg2, msg3) + self.assertTrue(msg2.fields_by_name['bytes_with_default'].has_default_value) + self.assertEquals( + b'a\xfb\x00c', + msg2.fields_by_name['bytes_with_default'].default_value) + + self.assertEqual(1, len(msg2.oneofs)) + self.assertEqual(1, len(msg2.oneofs_by_name)) + self.assertEqual(2, len(msg2.oneofs[0].fields)) + for name in ['oneof_int', 'oneof_string']: + self.assertEqual(msg2.oneofs[0], + msg2.fields_by_name[name].containing_oneof) + self.assertIn(msg2.fields_by_name[name], msg2.oneofs[0].fields) + def testFindMessageTypeByNameFailure(self): - try: + with self.assertRaises(KeyError): self.pool.FindMessageTypeByName('Does not exist') - self.fail('Expected KeyError') - except KeyError: - pass def testFindEnumTypeByName(self): enum1 = self.pool.FindEnumTypeByName( - 'net.proto2.python.internal.Factory1Enum') + 'google.protobuf.python.internal.Factory1Enum') self.assertIsInstance(enum1, descriptor.EnumDescriptor) self.assertEquals(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number) self.assertEquals(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number) nested_enum1 = self.pool.FindEnumTypeByName( - 'net.proto2.python.internal.Factory1Message.NestedFactory1Enum') + 'google.protobuf.python.internal.Factory1Message.NestedFactory1Enum') self.assertIsInstance(nested_enum1, descriptor.EnumDescriptor) self.assertEquals( 0, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_0'].number) @@ -189,13 +202,13 @@ class DescriptorPoolTest(unittest.TestCase): 1, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_1'].number) enum2 = self.pool.FindEnumTypeByName( - 'net.proto2.python.internal.Factory2Enum') + 'google.protobuf.python.internal.Factory2Enum') self.assertIsInstance(enum2, descriptor.EnumDescriptor) self.assertEquals(0, enum2.values_by_name['FACTORY_2_VALUE_0'].number) self.assertEquals(1, enum2.values_by_name['FACTORY_2_VALUE_1'].number) nested_enum2 = self.pool.FindEnumTypeByName( - 'net.proto2.python.internal.Factory2Message.NestedFactory2Enum') + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum') self.assertIsInstance(nested_enum2, descriptor.EnumDescriptor) self.assertEquals( 0, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_0'].number) @@ -203,11 +216,8 @@ class DescriptorPoolTest(unittest.TestCase): 1, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_1'].number) def testFindEnumTypeByNameFailure(self): - try: + with self.assertRaises(KeyError): self.pool.FindEnumTypeByName('Does not exist') - self.fail('Expected KeyError') - except KeyError: - pass def testUserDefinedDB(self): db = descriptor_database.DescriptorDatabase() @@ -216,5 +226,339 @@ class DescriptorPoolTest(unittest.TestCase): db.Add(self.factory_test2_fd) self.testFindMessageTypeByName() + def testComplexNesting(self): + test1_desc = descriptor_pb2.FileDescriptorProto.FromString( + descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) + test2_desc = descriptor_pb2.FileDescriptorProto.FromString( + descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(test1_desc) + self.pool.Add(test2_desc) + TEST1_FILE.CheckFile(self, self.pool) + TEST2_FILE.CheckFile(self, self.pool) + + + +class ProtoFile(object): + + def __init__(self, name, package, messages, dependencies=None): + self.name = name + self.package = package + self.messages = messages + self.dependencies = dependencies or [] + + def CheckFile(self, test, pool): + file_desc = pool.FindFileByName(self.name) + test.assertEquals(self.name, file_desc.name) + test.assertEquals(self.package, file_desc.package) + dependencies_names = [f.name for f in file_desc.dependencies] + test.assertEqual(self.dependencies, dependencies_names) + for name, msg_type in self.messages.items(): + msg_type.CheckType(test, None, name, file_desc) + + +class EnumType(object): + + def __init__(self, values): + self.values = values + + def CheckType(self, test, msg_desc, name, file_desc): + enum_desc = msg_desc.enum_types_by_name[name] + test.assertEqual(name, enum_desc.name) + expected_enum_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_enum_full_name, enum_desc.full_name) + test.assertEqual(msg_desc, enum_desc.containing_type) + test.assertEqual(file_desc, enum_desc.file) + for index, (value, number) in enumerate(self.values): + value_desc = enum_desc.values_by_name[value] + test.assertEqual(value, value_desc.name) + test.assertEqual(index, value_desc.index) + test.assertEqual(number, value_desc.number) + test.assertEqual(enum_desc, value_desc.type) + test.assertIn(value, msg_desc.enum_values_by_name) + + +class MessageType(object): + + def __init__(self, type_dict, field_list, is_extendable=False, + extensions=None): + self.type_dict = type_dict + self.field_list = field_list + self.is_extendable = is_extendable + self.extensions = extensions or [] + + def CheckType(self, test, containing_type_desc, name, file_desc): + if containing_type_desc is None: + desc = file_desc.message_types_by_name[name] + expected_full_name = '.'.join([file_desc.package, name]) + else: + desc = containing_type_desc.nested_types_by_name[name] + expected_full_name = '.'.join([containing_type_desc.full_name, name]) + + test.assertEqual(name, desc.name) + test.assertEqual(expected_full_name, desc.full_name) + test.assertEqual(containing_type_desc, desc.containing_type) + test.assertEqual(desc.file, file_desc) + test.assertEqual(self.is_extendable, desc.is_extendable) + for name, subtype in self.type_dict.items(): + subtype.CheckType(test, desc, name, file_desc) + + for index, (name, field) in enumerate(self.field_list): + field.CheckField(test, desc, name, index) + + for index, (name, field) in enumerate(self.extensions): + field.CheckField(test, desc, name, index) + + +class EnumField(object): + + def __init__(self, number, type_name, default_value): + self.number = number + self.type_name = type_name + self.default_value = default_value + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.fields_by_name[name] + enum_desc = msg_desc.enum_types_by_name[self.type_name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(index, field_desc.index) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(descriptor.FieldDescriptor.TYPE_ENUM, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_ENUM, + field_desc.cpp_type) + test.assertTrue(field_desc.has_default_value) + test.assertEqual(enum_desc.values_by_name[self.default_value].index, + field_desc.default_value) + test.assertEqual(msg_desc, field_desc.containing_type) + test.assertEqual(enum_desc, field_desc.enum_type) + + +class MessageField(object): + + def __init__(self, number, type_name): + self.number = number + self.type_name = type_name + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.fields_by_name[name] + field_type_desc = msg_desc.nested_types_by_name[self.type_name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(index, field_desc.index) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(descriptor.FieldDescriptor.TYPE_MESSAGE, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_MESSAGE, + field_desc.cpp_type) + test.assertFalse(field_desc.has_default_value) + test.assertEqual(msg_desc, field_desc.containing_type) + test.assertEqual(field_type_desc, field_desc.message_type) + + +class StringField(object): + + def __init__(self, number, default_value): + self.number = number + self.default_value = default_value + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.fields_by_name[name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(index, field_desc.index) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(descriptor.FieldDescriptor.TYPE_STRING, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_STRING, + field_desc.cpp_type) + test.assertTrue(field_desc.has_default_value) + test.assertEqual(self.default_value, field_desc.default_value) + + +class ExtensionField(object): + + def __init__(self, number, extended_type): + self.number = number + self.extended_type = extended_type + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.extensions_by_name[name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(index, field_desc.index) + test.assertEqual(descriptor.FieldDescriptor.TYPE_MESSAGE, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_MESSAGE, + field_desc.cpp_type) + test.assertFalse(field_desc.has_default_value) + test.assertTrue(field_desc.is_extension) + test.assertEqual(msg_desc, field_desc.extension_scope) + test.assertEqual(msg_desc, field_desc.message_type) + test.assertEqual(self.extended_type, field_desc.containing_type.name) + + +class AddDescriptorTest(basetest.TestCase): + + def _TestMessage(self, prefix): + pool = descriptor_pool.DescriptorPool() + pool.AddDescriptor(unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.TestAllTypes', + pool.FindMessageTypeByName( + prefix + 'protobuf_unittest.TestAllTypes').full_name) + + # AddDescriptor is not recursive. + with self.assertRaises(KeyError): + pool.FindMessageTypeByName( + prefix + 'protobuf_unittest.TestAllTypes.NestedMessage') + + pool.AddDescriptor(unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedMessage', + pool.FindMessageTypeByName( + prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').full_name) + + # Files are implicitly also indexed when messages are added. + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileContainingSymbol( + prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').name) + + def testMessage(self): + self._TestMessage('') + self._TestMessage('.') + + def _TestEnum(self, prefix): + pool = descriptor_pool.DescriptorPool() + pool.AddEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.ForeignEnum', + pool.FindEnumTypeByName( + prefix + 'protobuf_unittest.ForeignEnum').full_name) + + # AddEnumDescriptor is not recursive. + with self.assertRaises(KeyError): + pool.FindEnumTypeByName( + prefix + 'protobuf_unittest.ForeignEnum.NestedEnum') + + pool.AddEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedEnum', + pool.FindEnumTypeByName( + prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) + + # Files are implicitly also indexed when enums are added. + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileContainingSymbol( + prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').name) + + def testEnum(self): + self._TestEnum('') + self._TestEnum('.') + + def testFile(self): + pool = descriptor_pool.DescriptorPool() + pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR) + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + # AddFileDescriptor is not recursive; messages and enums within files must + # be explicitly registered. + with self.assertRaises(KeyError): + pool.FindFileContainingSymbol( + 'protobuf_unittest.TestAllTypes') + + +TEST1_FILE = ProtoFile( + 'google/protobuf/internal/descriptor_pool_test1.proto', + 'google.protobuf.python.internal', + { + 'DescriptorPoolTest1': MessageType({ + 'NestedEnum': EnumType([('ALPHA', 1), ('BETA', 2)]), + 'NestedMessage': MessageType({ + 'NestedEnum': EnumType([('EPSILON', 5), ('ZETA', 6)]), + 'DeepNestedMessage': MessageType({ + 'NestedEnum': EnumType([('ETA', 7), ('THETA', 8)]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'ETA')), + ('nested_field', StringField(2, 'theta')), + ]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'ZETA')), + ('nested_field', StringField(2, 'beta')), + ('deep_nested_message', MessageField(3, 'DeepNestedMessage')), + ]) + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'BETA')), + ('nested_message', MessageField(2, 'NestedMessage')), + ], is_extendable=True), + + 'DescriptorPoolTest2': MessageType({ + 'NestedEnum': EnumType([('GAMMA', 3), ('DELTA', 4)]), + 'NestedMessage': MessageType({ + 'NestedEnum': EnumType([('IOTA', 9), ('KAPPA', 10)]), + 'DeepNestedMessage': MessageType({ + 'NestedEnum': EnumType([('LAMBDA', 11), ('MU', 12)]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'MU')), + ('nested_field', StringField(2, 'lambda')), + ]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'IOTA')), + ('nested_field', StringField(2, 'delta')), + ('deep_nested_message', MessageField(3, 'DeepNestedMessage')), + ]) + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'GAMMA')), + ('nested_message', MessageField(2, 'NestedMessage')), + ]), + }) + + +TEST2_FILE = ProtoFile( + 'google/protobuf/internal/descriptor_pool_test2.proto', + 'google.protobuf.python.internal', + { + 'DescriptorPoolTest3': MessageType({ + 'NestedEnum': EnumType([('NU', 13), ('XI', 14)]), + 'NestedMessage': MessageType({ + 'NestedEnum': EnumType([('OMICRON', 15), ('PI', 16)]), + 'DeepNestedMessage': MessageType({ + 'NestedEnum': EnumType([('RHO', 17), ('SIGMA', 18)]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'RHO')), + ('nested_field', StringField(2, 'sigma')), + ]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'PI')), + ('nested_field', StringField(2, 'nu')), + ('deep_nested_message', MessageField(3, 'DeepNestedMessage')), + ]) + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'XI')), + ('nested_message', MessageField(2, 'NestedMessage')), + ], extensions=[ + ('descriptor_pool_test', + ExtensionField(1001, 'DescriptorPoolTest1')), + ]), + }, + dependencies=['google/protobuf/internal/descriptor_pool_test1.proto']) + + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/descriptor_pool_test1.proto b/python/google/protobuf/internal/descriptor_pool_test1.proto new file mode 100644 index 00000000..c11dcc0e --- /dev/null +++ b/python/google/protobuf/internal/descriptor_pool_test1.proto @@ -0,0 +1,94 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package google.protobuf.python.internal; + + +message DescriptorPoolTest1 { + extensions 1000 to max; + + enum NestedEnum { + ALPHA = 1; + BETA = 2; + } + + optional NestedEnum nested_enum = 1 [default = BETA]; + + message NestedMessage { + enum NestedEnum { + EPSILON = 5; + ZETA = 6; + } + optional NestedEnum nested_enum = 1 [default = ZETA]; + optional string nested_field = 2 [default = "beta"]; + optional DeepNestedMessage deep_nested_message = 3; + + message DeepNestedMessage { + enum NestedEnum { + ETA = 7; + THETA = 8; + } + optional NestedEnum nested_enum = 1 [default = ETA]; + optional string nested_field = 2 [default = "theta"]; + } + } + + optional NestedMessage nested_message = 2; +} + +message DescriptorPoolTest2 { + enum NestedEnum { + GAMMA = 3; + DELTA = 4; + } + + optional NestedEnum nested_enum = 1 [default = GAMMA]; + + message NestedMessage { + enum NestedEnum { + IOTA = 9; + KAPPA = 10; + } + optional NestedEnum nested_enum = 1 [default = IOTA]; + optional string nested_field = 2 [default = "delta"]; + optional DeepNestedMessage deep_nested_message = 3; + + message DeepNestedMessage { + enum NestedEnum { + LAMBDA = 11; + MU = 12; + } + optional NestedEnum nested_enum = 1 [default = MU]; + optional string nested_field = 2 [default = "lambda"]; + } + } + + optional NestedMessage nested_message = 2; +} diff --git a/python/google/protobuf/internal/descriptor_pool_test2.proto b/python/google/protobuf/internal/descriptor_pool_test2.proto new file mode 100644 index 00000000..d97d39b4 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_pool_test2.proto @@ -0,0 +1,70 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package google.protobuf.python.internal; + +import "google/protobuf/internal/descriptor_pool_test1.proto"; + + +message DescriptorPoolTest3 { + + extend DescriptorPoolTest1 { + optional DescriptorPoolTest3 descriptor_pool_test = 1001; + } + + enum NestedEnum { + NU = 13; + XI = 14; + } + + optional NestedEnum nested_enum = 1 [default = XI]; + + message NestedMessage { + enum NestedEnum { + OMICRON = 15; + PI = 16; + } + optional NestedEnum nested_enum = 1 [default = PI]; + optional string nested_field = 2 [default = "nu"]; + optional DeepNestedMessage deep_nested_message = 3; + + message DeepNestedMessage { + enum NestedEnum { + RHO = 17; + SIGMA = 18; + } + optional NestedEnum nested_enum = 1 [default = RHO]; + optional string nested_field = 2 [default = "sigma"]; + } + } + + optional NestedMessage nested_message = 2; +} + diff --git a/python/google/protobuf/internal/descriptor_python_test.py b/python/google/protobuf/internal/descriptor_python_test.py new file mode 100644 index 00000000..b3a15710 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_python_test.py @@ -0,0 +1,54 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unittest for descriptor.py for the pure Python implementation.""" + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import +from google.protobuf.internal.descriptor_test import * + + +class ConfirmPurePythonTest(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('python', api_implementation.Type()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index c74f882e..d20d9457 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -34,7 +34,7 @@ __author__ = 'robinson@google.com (Will Robinson)' -import unittest +from google.apputils import basetest from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 @@ -48,7 +48,7 @@ name: 'TestEmptyMessage' """ -class DescriptorTest(unittest.TestCase): +class DescriptorTest(basetest.TestCase): def setUp(self): self.my_file = descriptor.FileDescriptor( @@ -244,7 +244,7 @@ class DescriptorTest(unittest.TestCase): unittest_custom_options_pb2.double_opt]) self.assertEqual("Hello, \"World\"", message_options.Extensions[ unittest_custom_options_pb2.string_opt]) - self.assertEqual("Hello\0World", message_options.Extensions[ + self.assertEqual(b"Hello\0World", message_options.Extensions[ unittest_custom_options_pb2.bytes_opt]) dummy_enum = unittest_custom_options_pb2.DummyMessageContainingEnum self.assertEqual( @@ -395,7 +395,7 @@ class DescriptorTest(unittest.TestCase): self.assertEqual(self.my_file.package, 'protobuf_unittest') -class DescriptorCopyToProtoTest(unittest.TestCase): +class DescriptorCopyToProtoTest(basetest.TestCase): """Tests for CopyTo functions of Descriptor.""" def _AssertProtoEqual(self, actual_proto, expected_class, expected_ascii): @@ -530,47 +530,49 @@ class DescriptorCopyToProtoTest(unittest.TestCase): descriptor_pb2.DescriptorProto, TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII) - def testCopyToProto_FileDescriptor(self): - UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = (""" - name: 'google/protobuf/unittest_import.proto' - package: 'protobuf_unittest_import' - dependency: 'google/protobuf/unittest_import_public.proto' - message_type: < - name: 'ImportMessage' - field: < - name: 'd' - number: 1 - label: 1 # Optional - type: 5 # TYPE_INT32 - > - > - """ + - """enum_type: < - name: 'ImportEnum' - value: < - name: 'IMPORT_FOO' - number: 7 - > - value: < - name: 'IMPORT_BAR' - number: 8 - > - value: < - name: 'IMPORT_BAZ' - number: 9 - > - > - options: < - java_package: 'com.google.protobuf.test' - optimize_for: 1 # SPEED - > - public_dependency: 0 - """) - - self._InternalTestCopyToProto( - unittest_import_pb2.DESCRIPTOR, - descriptor_pb2.FileDescriptorProto, - UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII) + # Disable this test so we can make changes to the proto file. + # TODO(xiaofeng): Enable this test after cl/55530659 is submitted. + # + # def testCopyToProto_FileDescriptor(self): + # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = (""" + # name: 'google/protobuf/unittest_import.proto' + # package: 'protobuf_unittest_import' + # dependency: 'google/protobuf/unittest_import_public.proto' + # message_type: < + # name: 'ImportMessage' + # field: < + # name: 'd' + # number: 1 + # label: 1 # Optional + # type: 5 # TYPE_INT32 + # > + # > + # """ + + # """enum_type: < + # name: 'ImportEnum' + # value: < + # name: 'IMPORT_FOO' + # number: 7 + # > + # value: < + # name: 'IMPORT_BAR' + # number: 8 + # > + # value: < + # name: 'IMPORT_BAZ' + # number: 9 + # > + # > + # options: < + # java_package: 'com.google.protobuf.test' + # optimize_for: 1 # SPEED + # > + # public_dependency: 0 + # """) + # self._InternalTestCopyToProto( + # unittest_import_pb2.DESCRIPTOR, + # descriptor_pb2.FileDescriptorProto, + # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII) def testCopyToProto_ServiceDescriptor(self): TEST_SERVICE_ASCII = """ @@ -586,28 +588,82 @@ class DescriptorCopyToProtoTest(unittest.TestCase): output_type: '.protobuf_unittest.BarResponse' > """ - self._InternalTestCopyToProto( unittest_pb2.TestService.DESCRIPTOR, descriptor_pb2.ServiceDescriptorProto, TEST_SERVICE_ASCII) -class MakeDescriptorTest(unittest.TestCase): +class MakeDescriptorTest(basetest.TestCase): + + def testMakeDescriptorWithNestedFields(self): + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.name = 'Foo2' + message_type = file_descriptor_proto.message_type.add() + message_type.name = file_descriptor_proto.name + nested_type = message_type.nested_type.add() + nested_type.name = 'Sub' + enum_type = nested_type.enum_type.add() + enum_type.name = 'FOO' + enum_type_val = enum_type.value.add() + enum_type_val.name = 'BAR' + enum_type_val.number = 3 + field = message_type.field.add() + field.number = 1 + field.name = 'uint64_field' + field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + field.type = descriptor.FieldDescriptor.TYPE_UINT64 + field = message_type.field.add() + field.number = 2 + field.name = 'nested_message_field' + field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + field.type = descriptor.FieldDescriptor.TYPE_MESSAGE + field.type_name = 'Sub' + enum_field = nested_type.field.add() + enum_field.number = 2 + enum_field.name = 'bar_field' + enum_field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + enum_field.type = descriptor.FieldDescriptor.TYPE_ENUM + enum_field.type_name = 'Foo2.Sub.FOO' + + result = descriptor.MakeDescriptor(message_type) + self.assertEqual(result.fields[0].cpp_type, + descriptor.FieldDescriptor.CPPTYPE_UINT64) + self.assertEqual(result.fields[1].cpp_type, + descriptor.FieldDescriptor.CPPTYPE_MESSAGE) + self.assertEqual(result.fields[1].message_type.containing_type, + result) + self.assertEqual(result.nested_types[0].fields[0].full_name, + 'Foo2.Sub.bar_field') + self.assertEqual(result.nested_types[0].fields[0].enum_type, + result.nested_types[0].enum_types[0]) + def testMakeDescriptorWithUnsignedIntField(self): file_descriptor_proto = descriptor_pb2.FileDescriptorProto() file_descriptor_proto.name = 'Foo' message_type = file_descriptor_proto.message_type.add() message_type.name = file_descriptor_proto.name + enum_type = message_type.enum_type.add() + enum_type.name = 'FOO' + enum_type_val = enum_type.value.add() + enum_type_val.name = 'BAR' + enum_type_val.number = 3 field = message_type.field.add() field.number = 1 field.name = 'uint64_field' field.label = descriptor.FieldDescriptor.LABEL_REQUIRED field.type = descriptor.FieldDescriptor.TYPE_UINT64 + enum_field = message_type.field.add() + enum_field.number = 2 + enum_field.name = 'bar_field' + enum_field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + enum_field.type = descriptor.FieldDescriptor.TYPE_ENUM + enum_field.type_name = 'Foo.FOO' + result = descriptor.MakeDescriptor(message_type) self.assertEqual(result.fields[0].cpp_type, descriptor.FieldDescriptor.CPPTYPE_UINT64) if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py index 777975e8..0a7c0417 100755 --- a/python/google/protobuf/internal/encoder.py +++ b/python/google/protobuf/internal/encoder.py @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#PY25 compatible for GAE. +# +# Copyright 2009 Google Inc. All Rights Reserved. + """Code for encoding protocol message primitives. Contains the logic for encoding every logical protocol field type @@ -67,6 +71,8 @@ sizer rather than when calling them. In particular: __author__ = 'kenton@google.com (Kenton Varda)' import struct +import sys ##PY25 +_PY2 = sys.version_info[0] < 3 ##PY25 from google.protobuf.internal import wire_format @@ -340,7 +346,8 @@ def MessageSetItemSizer(field_number): def _VarintEncoder(): """Return an encoder for a basic varint value (does not include tag).""" - local_chr = chr + local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25 +##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,)) def EncodeVarint(write, value): bits = value & 0x7f value >>= 7 @@ -357,7 +364,8 @@ def _SignedVarintEncoder(): """Return an encoder for a basic signed varint value (does not include tag).""" - local_chr = chr + local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25 +##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,)) def EncodeSignedVarint(write, value): if value < 0: value += (1 << 64) @@ -382,7 +390,8 @@ def _VarintBytes(value): pieces = [] _EncodeVarint(pieces.append, value) - return "".join(pieces) + return "".encode("latin1").join(pieces) ##PY25 +##!PY25 return b"".join(pieces) def TagBytes(field_number, wire_type): @@ -520,26 +529,33 @@ def _FloatingPointEncoder(wire_type, format): format: The format string to pass to struct.pack(). """ + b = _PY2 and (lambda x:x) or (lambda x:x.encode('latin1')) ##PY25 value_size = struct.calcsize(format) if value_size == 4: def EncodeNonFiniteOrRaise(write, value): # Remember that the serialized form uses little-endian byte order. if value == _POS_INF: - write('\x00\x00\x80\x7F') + write(b('\x00\x00\x80\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\x80\x7F') elif value == _NEG_INF: - write('\x00\x00\x80\xFF') + write(b('\x00\x00\x80\xFF')) ##PY25 +##!PY25 write(b'\x00\x00\x80\xFF') elif value != value: # NaN - write('\x00\x00\xC0\x7F') + write(b('\x00\x00\xC0\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\xC0\x7F') else: raise elif value_size == 8: def EncodeNonFiniteOrRaise(write, value): if value == _POS_INF: - write('\x00\x00\x00\x00\x00\x00\xF0\x7F') + write(b('\x00\x00\x00\x00\x00\x00\xF0\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F') elif value == _NEG_INF: - write('\x00\x00\x00\x00\x00\x00\xF0\xFF') + write(b('\x00\x00\x00\x00\x00\x00\xF0\xFF')) ##PY25 +##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF') elif value != value: # NaN - write('\x00\x00\x00\x00\x00\x00\xF8\x7F') + write(b('\x00\x00\x00\x00\x00\x00\xF8\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F') else: raise else: @@ -615,8 +631,10 @@ DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '= 3: + self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n') + else: + self.assertEqual(str(message), 'optional_double: 0.123456789123\n') + + def testUnknownFieldPrinting(self): + populated = unittest_pb2.TestAllTypes() + test_util.SetAllNonLazyFields(populated) + empty = unittest_pb2.TestEmptyMessage() + empty.ParseFromString(populated.SerializeToString()) + self.assertEqual(str(empty), '') + def testSortingRepeatedScalarFieldsDefaultComparator(self): """Check some different types with the default comparator.""" message = unittest_pb2.TestAllTypes() @@ -332,13 +366,13 @@ class MessageTest(unittest.TestCase): self.assertEqual(message.repeated_string[1], 'b') self.assertEqual(message.repeated_string[2], 'c') - message.repeated_bytes.append('a') - message.repeated_bytes.append('c') - message.repeated_bytes.append('b') + message.repeated_bytes.append(b'a') + message.repeated_bytes.append(b'c') + message.repeated_bytes.append(b'b') message.repeated_bytes.sort() - self.assertEqual(message.repeated_bytes[0], 'a') - self.assertEqual(message.repeated_bytes[1], 'b') - self.assertEqual(message.repeated_bytes[2], 'c') + self.assertEqual(message.repeated_bytes[0], b'a') + self.assertEqual(message.repeated_bytes[1], b'b') + self.assertEqual(message.repeated_bytes[2], b'c') def testSortingRepeatedScalarFieldsCustomComparator(self): """Check some different types with custom comparator.""" @@ -347,7 +381,7 @@ class MessageTest(unittest.TestCase): message.repeated_int32.append(-3) message.repeated_int32.append(-2) message.repeated_int32.append(-1) - message.repeated_int32.sort(lambda x,y: cmp(abs(x), abs(y))) + message.repeated_int32.sort(key=abs) self.assertEqual(message.repeated_int32[0], -1) self.assertEqual(message.repeated_int32[1], -2) self.assertEqual(message.repeated_int32[2], -3) @@ -355,7 +389,7 @@ class MessageTest(unittest.TestCase): message.repeated_string.append('aaa') message.repeated_string.append('bb') message.repeated_string.append('c') - message.repeated_string.sort(lambda x,y: cmp(len(x), len(y))) + message.repeated_string.sort(key=len) self.assertEqual(message.repeated_string[0], 'c') self.assertEqual(message.repeated_string[1], 'bb') self.assertEqual(message.repeated_string[2], 'aaa') @@ -370,7 +404,7 @@ class MessageTest(unittest.TestCase): message.repeated_nested_message.add().bb = 6 message.repeated_nested_message.add().bb = 5 message.repeated_nested_message.add().bb = 4 - message.repeated_nested_message.sort(lambda x,y: cmp(x.bb, y.bb)) + message.repeated_nested_message.sort(key=operator.attrgetter('bb')) self.assertEqual(message.repeated_nested_message[0].bb, 1) self.assertEqual(message.repeated_nested_message[1].bb, 2) self.assertEqual(message.repeated_nested_message[2].bb, 3) @@ -396,6 +430,7 @@ class MessageTest(unittest.TestCase): message.repeated_nested_message.sort(key=get_bb, reverse=True) self.assertEqual([k.bb for k in message.repeated_nested_message], [6, 5, 4, 3, 2, 1]) + if sys.version_info.major >= 3: return # No cmp sorting in PY3. message.repeated_nested_message.sort(sort_function=cmp_bb) self.assertEqual([k.bb for k in message.repeated_nested_message], [1, 2, 3, 4, 5, 6]) @@ -407,7 +442,6 @@ class MessageTest(unittest.TestCase): """Check sorting a scalar field using list.sort() arguments.""" message = unittest_pb2.TestAllTypes() - abs_cmp = lambda a, b: cmp(abs(a), abs(b)) message.repeated_int32.append(-3) message.repeated_int32.append(-2) message.repeated_int32.append(-1) @@ -415,12 +449,13 @@ class MessageTest(unittest.TestCase): self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) message.repeated_int32.sort(key=abs, reverse=True) self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) - message.repeated_int32.sort(sort_function=abs_cmp) - self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) - message.repeated_int32.sort(cmp=abs_cmp, reverse=True) - self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) + if sys.version_info.major < 3: # No cmp sorting in PY3. + abs_cmp = lambda a, b: cmp(abs(a), abs(b)) + message.repeated_int32.sort(sort_function=abs_cmp) + self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) + message.repeated_int32.sort(cmp=abs_cmp, reverse=True) + self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) - len_cmp = lambda a, b: cmp(len(a), len(b)) message.repeated_string.append('aaa') message.repeated_string.append('bb') message.repeated_string.append('c') @@ -428,10 +463,47 @@ class MessageTest(unittest.TestCase): self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) message.repeated_string.sort(key=len, reverse=True) self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) - message.repeated_string.sort(sort_function=len_cmp) - self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) - message.repeated_string.sort(cmp=len_cmp, reverse=True) - self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + if sys.version_info.major < 3: # No cmp sorting in PY3. + len_cmp = lambda a, b: cmp(len(a), len(b)) + message.repeated_string.sort(sort_function=len_cmp) + self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) + message.repeated_string.sort(cmp=len_cmp, reverse=True) + self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + + def testRepeatedFieldsComparable(self): + m1 = unittest_pb2.TestAllTypes() + m2 = unittest_pb2.TestAllTypes() + m1.repeated_int32.append(0) + m1.repeated_int32.append(1) + m1.repeated_int32.append(2) + m2.repeated_int32.append(0) + m2.repeated_int32.append(1) + m2.repeated_int32.append(2) + m1.repeated_nested_message.add().bb = 1 + m1.repeated_nested_message.add().bb = 2 + m1.repeated_nested_message.add().bb = 3 + m2.repeated_nested_message.add().bb = 1 + m2.repeated_nested_message.add().bb = 2 + m2.repeated_nested_message.add().bb = 3 + + if sys.version_info.major >= 3: return # No cmp() in PY3. + + # These comparisons should not raise errors. + _ = m1 < m2 + _ = m1.repeated_nested_message < m2.repeated_nested_message + + # Make sure cmp always works. If it wasn't defined, these would be + # id() comparisons and would all fail. + self.assertEqual(cmp(m1, m2), 0) + self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0) + self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0) + self.assertEqual(cmp(m1.repeated_nested_message, + m2.repeated_nested_message), 0) + with self.assertRaises(TypeError): + # Can't compare repeated composite containers to lists. + cmp(m1.repeated_nested_message, m2.repeated_nested_message[:]) + + # TODO(anuraag): Implement extensiondict comparison in C++ and then add test def testParsingMerge(self): """Check the merge behavior when a required or optional field appears @@ -482,6 +554,87 @@ class MessageTest(unittest.TestCase): self.assertEqual(len(parsing_merge.Extensions[ unittest_pb2.TestParsingMerge.repeated_ext]), 3) + def ensureNestedMessageExists(self, msg, attribute): + """Make sure that a nested message object exists. + + As soon as a nested message attribute is accessed, it will be present in the + _fields dict, without being marked as actually being set. + """ + getattr(msg, attribute) + self.assertFalse(msg.HasField(attribute)) + + def testOneofGetCaseNonexistingField(self): + m = unittest_pb2.TestAllTypes() + self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field') + + def testOneofSemantics(self): + m = unittest_pb2.TestAllTypes() + self.assertIs(None, m.WhichOneof('oneof_field')) + + m.oneof_uint32 = 11 + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_uint32')) + + m.oneof_string = u'foo' + self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertTrue(m.HasField('oneof_string')) + + m.oneof_nested_message.bb = 11 + self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_string')) + self.assertTrue(m.HasField('oneof_nested_message')) + + m.oneof_bytes = b'bb' + self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_nested_message')) + self.assertTrue(m.HasField('oneof_bytes')) + + def testOneofCompositeFieldReadAccess(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + + self.ensureNestedMessageExists(m, 'oneof_nested_message') + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + self.assertEqual(11, m.oneof_uint32) + + def testOneofHasField(self): + m = unittest_pb2.TestAllTypes() + self.assertFalse(m.HasField('oneof_field')) + m.oneof_uint32 = 11 + self.assertTrue(m.HasField('oneof_field')) + m.oneof_bytes = b'bb' + self.assertTrue(m.HasField('oneof_field')) + m.ClearField('oneof_bytes') + self.assertFalse(m.HasField('oneof_field')) + + def testOneofClearField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + m.ClearField('oneof_field') + self.assertFalse(m.HasField('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertIs(None, m.WhichOneof('oneof_field')) + + def testOneofClearSetField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + m.ClearField('oneof_uint32') + self.assertFalse(m.HasField('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertIs(None, m.WhichOneof('oneof_field')) + + def testOneofClearUnsetField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + self.ensureNestedMessageExists(m, 'oneof_nested_message') + m.ClearField('oneof_nested_message') + self.assertEqual(11, m.oneof_uint32) + self.assertTrue(m.HasField('oneof_field')) + self.assertTrue(m.HasField('oneof_uint32')) + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + + def testSortEmptyRepeatedCompositeContainer(self): """Exercise a scenario that has led to segfaults in the past. @@ -489,6 +642,35 @@ class MessageTest(unittest.TestCase): m = unittest_pb2.TestAllTypes() m.repeated_nested_message.sort() + def testHasFieldOnRepeatedField(self): + """Using HasField on a repeated field should raise an exception. + """ + m = unittest_pb2.TestAllTypes() + with self.assertRaises(ValueError) as _: + m.HasField('repeated_int32') + + +class ValidTypeNamesTest(basetest.TestCase): + + def assertImportFromName(self, msg, base_name): + # Parse to extra 'some.name' as a string. + tp_name = str(type(msg)).split("'")[1] + valid_names = ('Repeated%sContainer' % base_name, + 'Repeated%sFieldContainer' % base_name) + self.assertTrue(any(tp_name.endswith(v) for v in valid_names), + '%r does end with any of %r' % (tp_name, valid_names)) + + parts = tp_name.split('.') + class_name = parts[-1] + module_name = '.'.join(parts[:-1]) + __import__(module_name, fromlist=[class_name]) + + def testTypeNamesCanBeImported(self): + # If import doesn't work, pickling won't work either. + pb = unittest_pb2.TestAllTypes() + self.assertImportFromName(pb.repeated_int32, 'Scalar') + self.assertImportFromName(pb.repeated_nested_message, 'Composite') + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/missing_enum_values.proto b/python/google/protobuf/internal/missing_enum_values.proto new file mode 100644 index 00000000..c9ae58b6 --- /dev/null +++ b/python/google/protobuf/internal/missing_enum_values.proto @@ -0,0 +1,50 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package google.protobuf.python.internal; + +message TestEnumValues { + enum NestedEnum { + ZERO = 0; + ONE = 1; + } + optional NestedEnum optional_nested_enum = 1; + repeated NestedEnum repeated_nested_enum = 2; + repeated NestedEnum packed_nested_enum = 3 [packed = true]; +} + +message TestMissingEnumValues { + enum NestedEnum { + TWO = 2; + } + optional NestedEnum optional_nested_enum = 1; + repeated NestedEnum repeated_nested_enum = 2; + repeated NestedEnum packed_nested_enum = 3 [packed = true]; +} diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 4bea57ac..9ee352d6 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# Keep it Python2.5 compatible for GAE. +# +# Copyright 2007 Google Inc. All Rights Reserved. +# # This code is meant to work on Python 2.4 and above only. # # TODO(robinson): Helpers for verbose, common checks like seeing if a @@ -50,11 +54,16 @@ this file*. __author__ = 'robinson@google.com (Will Robinson)' -try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO -import copy_reg +import sys +if sys.version_info[0] < 3: + try: + from cStringIO import StringIO as BytesIO + except ImportError: + from StringIO import StringIO as BytesIO + import copy_reg as copyreg +else: + from io import BytesIO + import copyreg import struct import weakref @@ -98,8 +107,8 @@ def InitMessage(descriptor, cls): _AddPropertiesForExtensions(descriptor, cls) _AddStaticMethods(cls) _AddMessageMethods(descriptor, cls) - _AddPrivateHelperMethods(cls) - copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) + _AddPrivateHelperMethods(descriptor, cls) + copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) # Stateless helpers for GeneratedProtocolMessageType below. @@ -176,7 +185,8 @@ def _AddSlots(message_descriptor, dictionary): '_is_present_in_parent', '_listener', '_listener_for_children', - '__weakref__'] + '__weakref__', + '_oneofs'] def _IsMessageSetExtension(field): @@ -272,7 +282,7 @@ def _DefaultValueConstructorForField(field): message._listener_for_children, field.message_type) return MakeRepeatedMessageDefault else: - type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + type_checker = type_checkers.GetTypeChecker(field) def MakeRepeatedScalarDefault(message): return containers.RepeatedScalarFieldContainer( message._listener_for_children, type_checker) @@ -301,6 +311,10 @@ def _AddInitMethod(message_descriptor, cls): self._cached_byte_size = 0 self._cached_byte_size_dirty = len(kwargs) > 0 self._fields = {} + # Contains a mapping from oneof field descriptors to the descriptor + # of the currently set field in that oneof field. + self._oneofs = {} + # _unknown_fields is () when empty for efficiency, and will be turned into # a list if fields are added. self._unknown_fields = () @@ -440,7 +454,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): """ proto_field_name = field.name property_name = _PropertyName(proto_field_name) - type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + type_checker = type_checkers.GetTypeChecker(field) default_value = field.default_value valid_values = set() @@ -450,14 +464,21 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): return self._fields.get(field, default_value) getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name - def setter(self, new_value): - type_checker.CheckValue(new_value) - self._fields[field] = new_value + def field_setter(self, new_value): + # pylint: disable=protected-access + self._fields[field] = type_checker.CheckValue(new_value) # Check _cached_byte_size_dirty inline to improve performance, since scalar # setters are called frequently. if not self._cached_byte_size_dirty: self._Modified() + if field.containing_oneof is not None: + def setter(self, new_value): + field_setter(self, new_value) + self._UpdateOneofState(field) + else: + setter = field_setter + setter.__module__ = None setter.__doc__ = 'Setter for %s.' % proto_field_name @@ -493,7 +514,10 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): if field_value is None: # Construct a new object to represent this field. field_value = message_type._concrete_class() # use field.message_type? - field_value._SetListener(self._listener_for_children) + field_value._SetListener( + _OneofListener(self, field) + if field.containing_oneof is not None + else self._listener_for_children) # Atomically check if another thread has preempted us and, if not, swap # in the new object we just created. If someone has preempted us, we @@ -589,6 +613,9 @@ def _AddHasFieldMethod(message_descriptor, cls): for field in message_descriptor.fields: if field.label != _FieldDescriptor.LABEL_REPEATED: singular_fields[field.name] = field + # Fields inside oneofs are never repeated (enforced by the compiler). + for field in message_descriptor.oneofs: + singular_fields[field.name] = field def HasField(self, field_name): try: @@ -597,11 +624,18 @@ def _AddHasFieldMethod(message_descriptor, cls): raise ValueError( 'Protocol message has no singular "%s" field.' % field_name) - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - value = self._fields.get(field) - return value is not None and value._is_present_in_parent + if isinstance(field, descriptor_mod.OneofDescriptor): + try: + return HasField(self, self._oneofs[field].name) + except KeyError: + return False else: - return field in self._fields + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(field) + return value is not None and value._is_present_in_parent + else: + return field in self._fields + cls.HasField = HasField @@ -611,7 +645,14 @@ def _AddClearFieldMethod(message_descriptor, cls): try: field = message_descriptor.fields_by_name[field_name] except KeyError: - raise ValueError('Protocol message has no "%s" field.' % field_name) + try: + field = message_descriptor.oneofs_by_name[field_name] + if field in self._oneofs: + field = self._oneofs[field] + else: + return + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) if field in self._fields: # Note: If the field is a sub-message, its listener will still point @@ -619,6 +660,9 @@ def _AddClearFieldMethod(message_descriptor, cls): # will call _Modified() and invalidate our byte size. Big deal. del self._fields[field] + if self._oneofs.get(field.containing_oneof, None) is field: + del self._oneofs[field.containing_oneof] + # Always call _Modified() -- even if nothing was changed, this is # a mutating method, and thus calling it should cause the field to become # present in the parent message. @@ -773,7 +817,7 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def SerializePartialToString(self): - out = StringIO() + out = BytesIO() self._InternalSerialize(out.write) return out.getvalue() cls.SerializePartialToString = SerializePartialToString @@ -796,7 +840,8 @@ def _AddMergeFromStringMethod(message_descriptor, cls): # The only reason _InternalParse would return early is if it # encountered an end-group tag. raise message_mod.DecodeError('Unexpected end-group tag.') - except IndexError: + except (IndexError, TypeError): + # Now ord(buf[p:p+1]) == ord('') gets TypeError. raise message_mod.DecodeError('Truncated message.') except struct.error, e: raise message_mod.DecodeError(e) @@ -857,7 +902,7 @@ def _AddIsInitializedMethod(message_descriptor, cls): errors.extend(self.FindInitializationErrors()) return False - for field, value in self._fields.iteritems(): + for field, value in list(self._fields.items()): # dict can change size! if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: if field.label == _FieldDescriptor.LABEL_REPEATED: for element in value: @@ -953,6 +998,24 @@ def _AddMergeFromMethod(cls): cls.MergeFrom = MergeFrom +def _AddWhichOneofMethod(message_descriptor, cls): + def WhichOneof(self, oneof_name): + """Returns the name of the currently set field inside a oneof, or None.""" + try: + field = message_descriptor.oneofs_by_name[oneof_name] + except KeyError: + raise ValueError( + 'Protocol message has no oneof "%s" field.' % oneof_name) + + nested_field = self._oneofs.get(field, None) + if nested_field is not None and self.HasField(nested_field.name): + return nested_field.name + else: + return None + + cls.WhichOneof = WhichOneof + + def _AddMessageMethods(message_descriptor, cls): """Adds implementations of all Message methods to cls.""" _AddListFieldsMethod(message_descriptor, cls) @@ -972,9 +1035,9 @@ def _AddMessageMethods(message_descriptor, cls): _AddMergeFromStringMethod(message_descriptor, cls) _AddIsInitializedMethod(message_descriptor, cls) _AddMergeFromMethod(cls) + _AddWhichOneofMethod(message_descriptor, cls) - -def _AddPrivateHelperMethods(cls): +def _AddPrivateHelperMethods(message_descriptor, cls): """Adds implementation of private helper methods to cls.""" def Modified(self): @@ -992,8 +1055,20 @@ def _AddPrivateHelperMethods(cls): self._is_present_in_parent = True self._listener.Modified() + def _UpdateOneofState(self, field): + """Sets field as the active field in its containing oneof. + + Will also delete currently active field in the oneof, if it is different + from the argument. Does not mark the message as modified. + """ + other_field = self._oneofs.setdefault(field.containing_oneof, field) + if other_field is not field: + del self._fields[other_field] + self._oneofs[field.containing_oneof] = field + cls._Modified = Modified cls.SetInParent = Modified + cls._UpdateOneofState = _UpdateOneofState class _Listener(object): @@ -1042,6 +1117,27 @@ class _Listener(object): pass +class _OneofListener(_Listener): + """Special listener implementation for setting composite oneof fields.""" + + def __init__(self, parent_message, field): + """Args: + parent_message: The message whose _Modified() method we should call when + we receive Modified() messages. + field: The descriptor of the field being set in the parent message. + """ + super(_OneofListener, self).__init__(parent_message) + self._field = field + + def Modified(self): + """Also updates the state of the containing oneof in the parent message.""" + try: + self._parent_message_weakref._UpdateOneofState(self._field) + super(_OneofListener, self).Modified() + except ReferenceError: + pass + + # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... # TODO(robinson): Unify error handling of "unknown extension" crap. # TODO(robinson): Support iteritems()-style iteration over all @@ -1133,9 +1229,10 @@ class _ExtensionDict(object): # It's slightly wasteful to lookup the type checker each time, # but we expect this to be a vanishingly uncommon case anyway. type_checker = type_checkers.GetTypeChecker( - extension_handle.cpp_type, extension_handle.type) - type_checker.CheckValue(value) - self._extended_message._fields[extension_handle] = value + extension_handle) + # pylint: disable=protected-access + self._extended_message._fields[extension_handle] = ( + type_checker.CheckValue(value)) self._extended_message._Modified() def _FindExtensionByName(self, name): diff --git a/python/google/protobuf/internal/reflection_cpp2_generated_test.py b/python/google/protobuf/internal/reflection_cpp2_generated_test.py new file mode 100755 index 00000000..d7fce5fa --- /dev/null +++ b/python/google/protobuf/internal/reflection_cpp2_generated_test.py @@ -0,0 +1,94 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unittest for reflection.py, which tests the generated C++ implementation.""" + +__author__ = 'jasonh@google.com (Jason Hsueh)' + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' + +from google.apputils import basetest +from google.protobuf.internal import api_implementation +from google.protobuf.internal import more_extensions_dynamic_pb2 +from google.protobuf.internal import more_extensions_pb2 +from google.protobuf.internal.reflection_test import * + + +class ReflectionCppTest(basetest.TestCase): + def testImplementationSetting(self): + self.assertEqual('cpp', api_implementation.Type()) + self.assertEqual(2, api_implementation.Version()) + + def testExtensionOfGeneratedTypeInDynamicFile(self): + """Tests that a file built dynamically can extend a generated C++ type. + + The C++ implementation uses a DescriptorPool that has the generated + DescriptorPool as an underlay. Typically, a type can only find + extensions in its own pool. With the python C-extension, the generated C++ + extendee may be available, but not the extension. This tests that the + C-extension implements the correct special handling to make such extensions + available. + """ + pb1 = more_extensions_pb2.ExtendedMessage() + # Test that basic accessors work. + self.assertFalse( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) + self.assertFalse( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) + pb1.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension] = 17 + pb1.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a = 24 + self.assertTrue( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) + self.assertTrue( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) + + # Now serialize the data and parse to a new message. + pb2 = more_extensions_pb2.ExtendedMessage() + pb2.MergeFromString(pb1.SerializeToString()) + + self.assertTrue( + pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) + self.assertTrue( + pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) + self.assertEqual( + 17, pb2.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension]) + self.assertEqual( + 24, + pb2.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a) + + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/reflection_cpp_generated_test.py b/python/google/protobuf/internal/reflection_cpp_generated_test.py deleted file mode 100755 index 2a0a5124..00000000 --- a/python/google/protobuf/internal/reflection_cpp_generated_test.py +++ /dev/null @@ -1,91 +0,0 @@ -#! /usr/bin/python -# -*- coding: utf-8 -*- -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Unittest for reflection.py, which tests the generated C++ implementation.""" - -__author__ = 'jasonh@google.com (Jason Hsueh)' - -import os -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' - -import unittest -from google.protobuf.internal import api_implementation -from google.protobuf.internal import more_extensions_dynamic_pb2 -from google.protobuf.internal import more_extensions_pb2 -from google.protobuf.internal.reflection_test import * - - -class ReflectionCppTest(unittest.TestCase): - def testImplementationSetting(self): - self.assertEqual('cpp', api_implementation.Type()) - - def testExtensionOfGeneratedTypeInDynamicFile(self): - """Tests that a file built dynamically can extend a generated C++ type. - - The C++ implementation uses a DescriptorPool that has the generated - DescriptorPool as an underlay. Typically, a type can only find - extensions in its own pool. With the python C-extension, the generated C++ - extendee may be available, but not the extension. This tests that the - C-extension implements the correct special handling to make such extensions - available. - """ - pb1 = more_extensions_pb2.ExtendedMessage() - # Test that basic accessors work. - self.assertFalse( - pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) - self.assertFalse( - pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) - pb1.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension] = 17 - pb1.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a = 24 - self.assertTrue( - pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) - self.assertTrue( - pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) - - # Now serialize the data and parse to a new message. - pb2 = more_extensions_pb2.ExtendedMessage() - pb2.MergeFromString(pb1.SerializeToString()) - - self.assertTrue( - pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) - self.assertTrue( - pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) - self.assertEqual( - 17, pb2.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension]) - self.assertEqual( - 24, - pb2.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index ed286461..b3c414c7 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -37,11 +37,12 @@ pure-Python protocol compiler. __author__ = 'robinson@google.com (Will Robinson)' +import copy import gc import operator import struct -import unittest +from google.apputils import basetest from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 @@ -49,6 +50,7 @@ from google.protobuf import descriptor_pb2 from google.protobuf import descriptor from google.protobuf import message from google.protobuf import reflection +from google.protobuf import text_format from google.protobuf.internal import api_implementation from google.protobuf.internal import more_extensions_pb2 from google.protobuf.internal import more_messages_pb2 @@ -102,7 +104,7 @@ class _MiniDecoder(object): return self._pos == len(self._bytes) -class ReflectionTest(unittest.TestCase): +class ReflectionTest(basetest.TestCase): def assertListsEqual(self, values, others): self.assertEqual(len(values), len(others)) @@ -533,7 +535,7 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(0.0, proto.optional_double) self.assertEqual(False, proto.optional_bool) self.assertEqual('', proto.optional_string) - self.assertEqual('', proto.optional_bytes) + self.assertEqual(b'', proto.optional_bytes) self.assertEqual(41, proto.default_int32) self.assertEqual(42, proto.default_int64) @@ -549,7 +551,7 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(52e3, proto.default_double) self.assertEqual(True, proto.default_bool) self.assertEqual('hello', proto.default_string) - self.assertEqual('world', proto.default_bytes) + self.assertEqual(b'world', proto.default_bytes) self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum) self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum) self.assertEqual(unittest_import_pb2.IMPORT_BAR, @@ -566,6 +568,17 @@ class ReflectionTest(unittest.TestCase): proto = unittest_pb2.TestAllTypes() self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field') + def testClearRemovesChildren(self): + # Make sure there aren't any implementation bugs that are only partially + # clearing the message (which can happen in the more complex C++ + # implementation which has parallel message lists). + proto = unittest_pb2.TestRequiredForeign() + for i in range(10): + proto.repeated_message.add() + proto2 = unittest_pb2.TestRequiredForeign() + proto.CopyFrom(proto2) + self.assertRaises(IndexError, lambda: proto.repeated_message[5]) + def testDisallowedAssignments(self): # It's illegal to assign values directly to repeated fields # or to nonrepeated composite fields. Ensure that this fails. @@ -594,6 +607,30 @@ class ReflectionTest(unittest.TestCase): self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) + def testIntegerTypes(self): + def TestGetAndDeserialize(field_name, value, expected_type): + proto = unittest_pb2.TestAllTypes() + setattr(proto, field_name, value) + self.assertTrue(isinstance(getattr(proto, field_name), expected_type)) + proto2 = unittest_pb2.TestAllTypes() + proto2.ParseFromString(proto.SerializeToString()) + self.assertTrue(isinstance(getattr(proto2, field_name), expected_type)) + + TestGetAndDeserialize('optional_int32', 1, int) + TestGetAndDeserialize('optional_int32', 1 << 30, int) + TestGetAndDeserialize('optional_uint32', 1 << 30, int) + if struct.calcsize('L') == 4: + # Python only has signed ints, so 32-bit python can't fit an uint32 + # in an int. + TestGetAndDeserialize('optional_uint32', 1 << 31, long) + else: + # 64-bit python can fit uint32 inside an int + TestGetAndDeserialize('optional_uint32', 1 << 31, int) + TestGetAndDeserialize('optional_int64', 1 << 30, long) + TestGetAndDeserialize('optional_int64', 1 << 60, long) + TestGetAndDeserialize('optional_uint64', 1 << 30, long) + TestGetAndDeserialize('optional_uint64', 1 << 60, long) + def testSingleScalarBoundsChecking(self): def TestMinAndMaxIntegers(field_name, expected_min, expected_max): pb = unittest_pb2.TestAllTypes() @@ -613,29 +650,6 @@ class ReflectionTest(unittest.TestCase): pb.optional_nested_enum = 1 self.assertEqual(1, pb.optional_nested_enum) - # Invalid enum values. - pb.optional_nested_enum = 0 - self.assertEqual(0, pb.optional_nested_enum) - - bytes_size_before = pb.ByteSize() - - pb.optional_nested_enum = 4 - self.assertEqual(4, pb.optional_nested_enum) - - pb.optional_nested_enum = 0 - self.assertEqual(0, pb.optional_nested_enum) - - # Make sure that setting the same enum field doesn't just add unknown - # fields (but overwrites them). - self.assertEqual(bytes_size_before, pb.ByteSize()) - - # Is the invalid value preserved after serialization? - serialized = pb.SerializeToString() - pb2 = unittest_pb2.TestAllTypes() - pb2.ParseFromString(serialized) - self.assertEqual(0, pb2.optional_nested_enum) - self.assertEqual(pb, pb2) - def testRepeatedScalarTypeSafety(self): proto = unittest_pb2.TestAllTypes() self.assertRaises(TypeError, proto.repeated_int32.append, 1.1) @@ -749,9 +763,9 @@ class ReflectionTest(unittest.TestCase): unittest_pb2.ForeignEnum.items()) proto = unittest_pb2.TestAllTypes() - self.assertEqual(['FOO', 'BAR', 'BAZ'], proto.NestedEnum.keys()) - self.assertEqual([1, 2, 3], proto.NestedEnum.values()) - self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3)], + self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys()) + self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values()) + self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)], proto.NestedEnum.items()) def testRepeatedScalars(self): @@ -1155,6 +1169,14 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(required is not extendee_proto.Extensions[extension]) self.assertTrue(not extendee_proto.HasExtension(extension)) + def testRegisteredExtensions(self): + self.assertTrue('protobuf_unittest.optional_int32_extension' in + unittest_pb2.TestAllExtensions._extensions_by_name) + self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number) + # Make sure extensions haven't been registered into types that shouldn't + # have any. + self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) + # If message A directly contains message B, and # a.HasField('b') is currently False, then mutating any # extension in B should change a.HasField('b') to True @@ -1451,6 +1473,19 @@ class ReflectionTest(unittest.TestCase): proto2 = unittest_pb2.TestAllExtensions() self.assertRaises(TypeError, proto1.CopyFrom, proto2) + def testDeepCopy(self): + proto1 = unittest_pb2.TestAllTypes() + proto1.optional_int32 = 1 + proto2 = copy.deepcopy(proto1) + self.assertEqual(1, proto2.optional_int32) + + proto1.repeated_int32.append(2) + proto1.repeated_int32.append(3) + container = copy.deepcopy(proto1.repeated_int32) + self.assertEqual([2, 3], container) + + # TODO(anuraag): Implement deepcopy for repeated composite / extension dict + def testClear(self): proto = unittest_pb2.TestAllTypes() # C++ implementation does not support lazy fields right now so leave it @@ -1496,11 +1531,23 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(6, foreign.c) nested.bb = 15 foreign.c = 16 - self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertFalse(proto.HasField('optional_nested_message')) self.assertEqual(0, proto.optional_nested_message.bb) - self.assertTrue(not proto.HasField('optional_foreign_message')) + self.assertFalse(proto.HasField('optional_foreign_message')) self.assertEqual(0, proto.optional_foreign_message.c) + def testOneOf(self): + proto = unittest_pb2.TestAllTypes() + proto.oneof_uint32 = 10 + proto.oneof_nested_message.bb = 11 + self.assertEqual(11, proto.oneof_nested_message.bb) + self.assertFalse(proto.HasField('oneof_uint32')) + nested = proto.oneof_nested_message + proto.oneof_string = 'abc' + self.assertEqual('abc', proto.oneof_string) + self.assertEqual(11, nested.bb) + self.assertFalse(proto.HasField('oneof_nested_message')) + def assertInitialized(self, proto): self.assertTrue(proto.IsInitialized()) # Neither method should raise an exception. @@ -1571,6 +1618,40 @@ class ReflectionTest(unittest.TestCase): self.assertFalse(proto.IsInitialized(errors)) self.assertEqual(errors, ['a', 'b', 'c']) + @basetest.unittest.skipIf( + api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, + 'Errors are only available from the most recent C++ implementation.') + def testFileDescriptorErrors(self): + file_name = 'test_file_descriptor_errors.proto' + package_name = 'test_file_descriptor_errors.proto' + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.name = file_name + file_descriptor_proto.package = package_name + m1 = file_descriptor_proto.message_type.add() + m1.name = 'msg1' + # Compiles the proto into the C++ descriptor pool + descriptor.FileDescriptor( + file_name, + package_name, + serialized_pb=file_descriptor_proto.SerializeToString()) + # Add a FileDescriptorProto that has duplicate symbols + another_file_name = 'another_test_file_descriptor_errors.proto' + file_descriptor_proto.name = another_file_name + m2 = file_descriptor_proto.message_type.add() + m2.name = 'msg2' + with self.assertRaises(TypeError) as cm: + descriptor.FileDescriptor( + another_file_name, + package_name, + serialized_pb=file_descriptor_proto.SerializeToString()) + self.assertTrue(hasattr(cm, 'exception'), '%s not raised' % + getattr(cm.expected, '__name__', cm.expected)) + self.assertIn('test_file_descriptor_errors.proto', str(cm.exception)) + # Error message will say something about this definition being a + # duplicate, though we don't check the message exactly to avoid a + # dependency on the C++ logging code. + self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) + def testStringUTF8Encoding(self): proto = unittest_pb2.TestAllTypes() @@ -1588,17 +1669,15 @@ class ReflectionTest(unittest.TestCase): proto.optional_string = str('Testing') self.assertEqual(proto.optional_string, unicode('Testing')) - if api_implementation.Type() == 'python': - # Values of type 'str' are also accepted as long as they can be - # encoded in UTF-8. - self.assertEqual(type(proto.optional_string), str) - # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII. self.assertRaises(ValueError, - setattr, proto, 'optional_string', str('a\x80a')) - # Assign a 'str' object which contains a UTF-8 encoded string. - self.assertRaises(ValueError, - setattr, proto, 'optional_string', 'Тест') + setattr, proto, 'optional_string', b'a\x80a') + if str is bytes: # PY2 + # Assign a 'str' object which contains a UTF-8 encoded string. + self.assertRaises(ValueError, + setattr, proto, 'optional_string', 'Тест') + else: + proto.optional_string = 'Тест' # No exception thrown. proto.optional_string = 'abc' @@ -1621,7 +1700,8 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(proto.ByteSize(), len(serialized)) raw = unittest_mset_pb2.RawMessageSet() - raw.MergeFromString(serialized) + bytes_read = raw.MergeFromString(serialized) + self.assertEqual(len(serialized), bytes_read) message2 = unittest_mset_pb2.TestMessageSetExtension2() @@ -1632,7 +1712,8 @@ class ReflectionTest(unittest.TestCase): # Check the actual bytes on the wire. self.assertTrue( raw.item[0].message.endswith(test_utf8_bytes)) - message2.MergeFromString(raw.item[0].message) + bytes_read = message2.MergeFromString(raw.item[0].message) + self.assertEqual(len(raw.item[0].message), bytes_read) self.assertEqual(type(message2.str), unicode) self.assertEqual(message2.str, test_utf8) @@ -1643,17 +1724,22 @@ class ReflectionTest(unittest.TestCase): # MergeFromString and thus has no way to throw the exception. # # The pure Python API always returns objects of type 'unicode' (UTF-8 - # encoded), or 'str' (in 7 bit ASCII). - bytes = raw.item[0].message.replace( - test_utf8_bytes, len(test_utf8_bytes) * '\xff') + # encoded), or 'bytes' (in 7 bit ASCII). + badbytes = raw.item[0].message.replace( + test_utf8_bytes, len(test_utf8_bytes) * b'\xff') unicode_decode_failed = False try: - message2.MergeFromString(bytes) - except UnicodeDecodeError as e: + message2.MergeFromString(badbytes) + except UnicodeDecodeError: unicode_decode_failed = True string_field = message2.str - self.assertTrue(unicode_decode_failed or type(string_field) == str) + self.assertTrue(unicode_decode_failed or type(string_field) is bytes) + + def testBytesInTextFormat(self): + proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') + self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', + unicode(proto)) def testEmptyNestedMessage(self): proto = unittest_pb2.TestAllTypes() @@ -1667,16 +1753,19 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(proto.HasField('optional_nested_message')) proto = unittest_pb2.TestAllTypes() - proto.optional_nested_message.MergeFromString('') + bytes_read = proto.optional_nested_message.MergeFromString(b'') + self.assertEqual(0, bytes_read) self.assertTrue(proto.HasField('optional_nested_message')) proto = unittest_pb2.TestAllTypes() - proto.optional_nested_message.ParseFromString('') + proto.optional_nested_message.ParseFromString(b'') self.assertTrue(proto.HasField('optional_nested_message')) serialized = proto.SerializeToString() proto2 = unittest_pb2.TestAllTypes() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertTrue(proto2.HasField('optional_nested_message')) def testSetInParent(self): @@ -1690,7 +1779,7 @@ class ReflectionTest(unittest.TestCase): # into separate TestCase classes. -class TestAllTypesEqualityTest(unittest.TestCase): +class TestAllTypesEqualityTest(basetest.TestCase): def setUp(self): self.first_proto = unittest_pb2.TestAllTypes() @@ -1706,7 +1795,7 @@ class TestAllTypesEqualityTest(unittest.TestCase): self.assertEqual(self.first_proto, self.second_proto) -class FullProtosEqualityTest(unittest.TestCase): +class FullProtosEqualityTest(basetest.TestCase): """Equality tests using completely-full protos as a starting point.""" @@ -1792,7 +1881,7 @@ class FullProtosEqualityTest(unittest.TestCase): self.assertEqual(self.first_proto, self.second_proto) -class ExtensionEqualityTest(unittest.TestCase): +class ExtensionEqualityTest(basetest.TestCase): def testExtensionEquality(self): first_proto = unittest_pb2.TestAllExtensions() @@ -1825,7 +1914,7 @@ class ExtensionEqualityTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) -class MutualRecursionEqualityTest(unittest.TestCase): +class MutualRecursionEqualityTest(basetest.TestCase): def testEqualityWithMutualRecursion(self): first_proto = unittest_pb2.TestMutualRecursionA() @@ -1837,7 +1926,7 @@ class MutualRecursionEqualityTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) -class ByteSizeTest(unittest.TestCase): +class ByteSizeTest(basetest.TestCase): def setUp(self): self.proto = unittest_pb2.TestAllTypes() @@ -2133,14 +2222,16 @@ class ByteSizeTest(unittest.TestCase): # * Handling of empty submessages (with and without "has" # bits set). -class SerializationTest(unittest.TestCase): +class SerializationTest(basetest.TestCase): def testSerializeEmtpyMessage(self): first_proto = unittest_pb2.TestAllTypes() second_proto = unittest_pb2.TestAllTypes() serialized = first_proto.SerializeToString() self.assertEqual(first_proto.ByteSize(), len(serialized)) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeAllFields(self): @@ -2149,7 +2240,9 @@ class SerializationTest(unittest.TestCase): test_util.SetAllFields(first_proto) serialized = first_proto.SerializeToString() self.assertEqual(first_proto.ByteSize(), len(serialized)) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeAllExtensions(self): @@ -2157,7 +2250,19 @@ class SerializationTest(unittest.TestCase): second_proto = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(first_proto) serialized = first_proto.SerializeToString() - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) + self.assertEqual(first_proto, second_proto) + + def testSerializeWithOptionalGroup(self): + first_proto = unittest_pb2.TestAllTypes() + second_proto = unittest_pb2.TestAllTypes() + first_proto.optionalgroup.a = 242 + serialized = first_proto.SerializeToString() + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeNegativeValues(self): @@ -2249,7 +2354,9 @@ class SerializationTest(unittest.TestCase): second_proto.optional_int32 = 100 second_proto.optional_nested_message.bb = 999 - second_proto.MergeFromString(serialized) + bytes_parsed = second_proto.MergeFromString(serialized) + self.assertEqual(len(serialized), bytes_parsed) + # Ensure that we append to repeated fields. self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string)) # Ensure that we overwrite nonrepeatd scalars. @@ -2274,20 +2381,28 @@ class SerializationTest(unittest.TestCase): raw = unittest_mset_pb2.RawMessageSet() self.assertEqual(False, raw.DESCRIPTOR.GetOptions().message_set_wire_format) - raw.MergeFromString(serialized) + self.assertEqual( + len(serialized), + raw.MergeFromString(serialized)) self.assertEqual(2, len(raw.item)) message1 = unittest_mset_pb2.TestMessageSetExtension1() - message1.MergeFromString(raw.item[0].message) + self.assertEqual( + len(raw.item[0].message), + message1.MergeFromString(raw.item[0].message)) self.assertEqual(123, message1.i) message2 = unittest_mset_pb2.TestMessageSetExtension2() - message2.MergeFromString(raw.item[1].message) + self.assertEqual( + len(raw.item[1].message), + message2.MergeFromString(raw.item[1].message)) self.assertEqual('foo', message2.str) # Deserialize using the MessageSet wire format. proto2 = unittest_mset_pb2.TestMessageSet() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertEqual(123, proto2.Extensions[extension1].i) self.assertEqual('foo', proto2.Extensions[extension2].str) @@ -2327,7 +2442,9 @@ class SerializationTest(unittest.TestCase): # Parse message using the message set wire format. proto = unittest_mset_pb2.TestMessageSet() - proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto.MergeFromString(serialized)) # Check that the message parsed well. extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 @@ -2345,7 +2462,9 @@ class SerializationTest(unittest.TestCase): proto2 = unittest_pb2.TestEmptyMessage() # Parsing this message should succeed. - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) # Now test with a int64 field set. proto = unittest_pb2.TestAllTypes() @@ -2355,7 +2474,9 @@ class SerializationTest(unittest.TestCase): # unknown. proto2 = unittest_pb2.TestEmptyMessage() # Parsing this message should succeed. - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) def _CheckRaises(self, exc_class, callable_obj, exception): """This method checks if the excpetion type and message are as expected.""" @@ -2406,11 +2527,15 @@ class SerializationTest(unittest.TestCase): partial = proto.SerializePartialToString() proto2 = unittest_pb2.TestRequired() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertEqual(1, proto2.a) self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) - proto2.ParseFromString(partial) + self.assertEqual( + len(partial), + proto2.MergeFromString(partial)) self.assertEqual(1, proto2.a) self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) @@ -2478,7 +2603,9 @@ class SerializationTest(unittest.TestCase): second_proto.packed_double.extend([1.0, 2.0]) second_proto.packed_sint32.append(4) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual([3, 1, 2], second_proto.packed_int32) self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double) self.assertEqual([4], second_proto.packed_sint32) @@ -2511,7 +2638,10 @@ class SerializationTest(unittest.TestCase): unpacked = unittest_pb2.TestUnpackedTypes() test_util.SetAllUnpackedFields(unpacked) packed = unittest_pb2.TestPackedTypes() - packed.MergeFromString(unpacked.SerializeToString()) + serialized = unpacked.SerializeToString() + self.assertEqual( + len(serialized), + packed.MergeFromString(serialized)) expected = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(expected) self.assertEqual(expected, packed) @@ -2520,7 +2650,10 @@ class SerializationTest(unittest.TestCase): packed = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(packed) unpacked = unittest_pb2.TestUnpackedTypes() - unpacked.MergeFromString(packed.SerializeToString()) + serialized = packed.SerializeToString() + self.assertEqual( + len(serialized), + unpacked.MergeFromString(serialized)) expected = unittest_pb2.TestUnpackedTypes() test_util.SetAllUnpackedFields(expected) self.assertEqual(expected, unpacked) @@ -2572,7 +2705,7 @@ class SerializationTest(unittest.TestCase): optional_int32=1, optional_string='foo', optional_bool=True, - optional_bytes='bar', + optional_bytes=b'bar', optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), optional_foreign_message=unittest_pb2.ForeignMessage(c=1), optional_nested_enum=unittest_pb2.TestAllTypes.FOO, @@ -2590,7 +2723,7 @@ class SerializationTest(unittest.TestCase): self.assertEqual(1, proto.optional_int32) self.assertEqual('foo', proto.optional_string) self.assertEqual(True, proto.optional_bool) - self.assertEqual('bar', proto.optional_bytes) + self.assertEqual(b'bar', proto.optional_bytes) self.assertEqual(1, proto.optional_nested_message.bb) self.assertEqual(1, proto.optional_foreign_message.c) self.assertEqual(unittest_pb2.TestAllTypes.FOO, @@ -2640,7 +2773,7 @@ class SerializationTest(unittest.TestCase): self.assertEqual(3, proto.repeated_int32[2]) -class OptionsTest(unittest.TestCase): +class OptionsTest(basetest.TestCase): def testMessageOptions(self): proto = unittest_mset_pb2.TestMessageSet() @@ -2667,5 +2800,135 @@ class OptionsTest(unittest.TestCase): +class ClassAPITest(basetest.TestCase): + + def testMakeClassWithNestedDescriptor(self): + leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '', + containing_type=None, fields=[], + nested_types=[], enum_types=[], + extensions=[]) + child_desc = descriptor.Descriptor('child', 'package.parent.child', '', + containing_type=None, fields=[], + nested_types=[leaf_desc], enum_types=[], + extensions=[]) + sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling', + '', containing_type=None, fields=[], + nested_types=[], enum_types=[], + extensions=[]) + parent_desc = descriptor.Descriptor('parent', 'package.parent', '', + containing_type=None, fields=[], + nested_types=[child_desc, sibling_desc], + enum_types=[], extensions=[]) + message_class = reflection.MakeClass(parent_desc) + self.assertIn('child', message_class.__dict__) + self.assertIn('sibling', message_class.__dict__) + self.assertIn('leaf', message_class.child.__dict__) + + def _GetSerializedFileDescriptor(self, name): + """Get a serialized representation of a test FileDescriptorProto. + + Args: + name: All calls to this must use a unique message name, to avoid + collisions in the cpp descriptor pool. + Returns: + A string containing the serialized form of a test FileDescriptorProto. + """ + file_descriptor_str = ( + 'message_type {' + ' name: "' + name + '"' + ' field {' + ' name: "flat"' + ' number: 1' + ' label: LABEL_REPEATED' + ' type: TYPE_UINT32' + ' }' + ' field {' + ' name: "bar"' + ' number: 2' + ' label: LABEL_OPTIONAL' + ' type: TYPE_MESSAGE' + ' type_name: "Bar"' + ' }' + ' nested_type {' + ' name: "Bar"' + ' field {' + ' name: "baz"' + ' number: 3' + ' label: LABEL_OPTIONAL' + ' type: TYPE_MESSAGE' + ' type_name: "Baz"' + ' }' + ' nested_type {' + ' name: "Baz"' + ' enum_type {' + ' name: "deep_enum"' + ' value {' + ' name: "VALUE_A"' + ' number: 0' + ' }' + ' }' + ' field {' + ' name: "deep"' + ' number: 4' + ' label: LABEL_OPTIONAL' + ' type: TYPE_UINT32' + ' }' + ' }' + ' }' + '}') + file_descriptor = descriptor_pb2.FileDescriptorProto() + text_format.Merge(file_descriptor_str, file_descriptor) + return file_descriptor.SerializeToString() + + def testParsingFlatClassWithExplicitClassDeclaration(self): + """Test that the generated class can parse a flat message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + + class MessageClass(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = msg_descriptor + msg = MessageClass() + msg_str = ( + 'flat: 0 ' + 'flat: 1 ' + 'flat: 2 ') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.flat, [0, 1, 2]) + + def testParsingFlatClass(self): + """Test that the generated class can parse a flat message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + msg_class = reflection.MakeClass(msg_descriptor) + msg = msg_class() + msg_str = ( + 'flat: 0 ' + 'flat: 1 ' + 'flat: 2 ') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.flat, [0, 1, 2]) + + def testParsingNestedClass(self): + """Test that the generated class can parse a nested message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + msg_class = reflection.MakeClass(msg_descriptor) + msg = msg_class() + msg_str = ( + 'bar {' + ' baz {' + ' deep: 4' + ' }' + '}') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.bar.baz.deep, 4) + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py index e04f8252..ef0981d9 100755 --- a/python/google/protobuf/internal/service_reflection_test.py +++ b/python/google/protobuf/internal/service_reflection_test.py @@ -34,13 +34,13 @@ __author__ = 'petar@google.com (Petar Petrov)' -import unittest +from google.apputils import basetest from google.protobuf import unittest_pb2 from google.protobuf import service_reflection from google.protobuf import service -class FooUnitTest(unittest.TestCase): +class FooUnitTest(basetest.TestCase): def testService(self): class MockRpcChannel(service.RpcChannel): @@ -133,4 +133,4 @@ class FooUnitTest(unittest.TestCase): if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py new file mode 100644 index 00000000..80bc8d6e --- /dev/null +++ b/python/google/protobuf/internal/symbol_database_test.py @@ -0,0 +1,120 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.symbol_database.""" + +from google.apputils import basetest +from google.protobuf import unittest_pb2 +from google.protobuf import symbol_database + + +class SymbolDatabaseTest(basetest.TestCase): + + def _Database(self): + db = symbol_database.SymbolDatabase() + # Register representative types from unittest_pb2. + db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR) + db.RegisterMessage(unittest_pb2.TestAllTypes) + db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage) + db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup) + db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup) + db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) + db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) + return db + + def testGetPrototype(self): + instance = self._Database().GetPrototype( + unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertTrue(instance is unittest_pb2.TestAllTypes) + + def testGetMessages(self): + messages = self._Database().GetMessages( + ['google/protobuf/unittest.proto']) + self.assertTrue( + unittest_pb2.TestAllTypes is + messages['protobuf_unittest.TestAllTypes']) + + def testGetSymbol(self): + self.assertEquals( + unittest_pb2.TestAllTypes, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes')) + self.assertEquals( + unittest_pb2.TestAllTypes.NestedMessage, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes.NestedMessage')) + self.assertEquals( + unittest_pb2.TestAllTypes.OptionalGroup, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes.OptionalGroup')) + self.assertEquals( + unittest_pb2.TestAllTypes.RepeatedGroup, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes.RepeatedGroup')) + + def testEnums(self): + # Check registration of types in the pool. + self.assertEquals( + 'protobuf_unittest.ForeignEnum', + self._Database().pool.FindEnumTypeByName( + 'protobuf_unittest.ForeignEnum').full_name) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedEnum', + self._Database().pool.FindEnumTypeByName( + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) + + def testFindMessageTypeByName(self): + self.assertEquals( + 'protobuf_unittest.TestAllTypes', + self._Database().pool.FindMessageTypeByName( + 'protobuf_unittest.TestAllTypes').full_name) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedMessage', + self._Database().pool.FindMessageTypeByName( + 'protobuf_unittest.TestAllTypes.NestedMessage').full_name) + + def testFindFindContainingSymbol(self): + # Lookup based on either enum or message. + self.assertEquals( + 'google/protobuf/unittest.proto', + self._Database().pool.FindFileContainingSymbol( + 'protobuf_unittest.TestAllTypes.NestedEnum').name) + self.assertEquals( + 'google/protobuf/unittest.proto', + self._Database().pool.FindFileContainingSymbol( + 'protobuf_unittest.TestAllTypes').name) + + def testFindFileByName(self): + self.assertEquals( + 'google/protobuf/unittest.proto', + self._Database().pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index be8ae7be..350d1c6d 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -66,14 +66,8 @@ def SetAllNonLazyFields(message): message.optional_float = 111 message.optional_double = 112 message.optional_bool = True - # TODO(robinson): Firmly spec out and test how - # protos interact with unicode. One specific example: - # what happens if we change the literal below to - # u'115'? What *should* happen? Still some discussion - # to finish with Kenton about bytes vs. strings - # and forcing everything to be utf8. :-/ - message.optional_string = '115' - message.optional_bytes = '116' + message.optional_string = u'115' + message.optional_bytes = b'116' message.optionalgroup.a = 117 message.optional_nested_message.bb = 118 @@ -85,8 +79,8 @@ def SetAllNonLazyFields(message): message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ - message.optional_string_piece = '124' - message.optional_cord = '125' + message.optional_string_piece = u'124' + message.optional_cord = u'125' # # Repeated fields. @@ -105,8 +99,8 @@ def SetAllNonLazyFields(message): message.repeated_float.append(211) message.repeated_double.append(212) message.repeated_bool.append(True) - message.repeated_string.append('215') - message.repeated_bytes.append('216') + message.repeated_string.append(u'215') + message.repeated_bytes.append(b'216') message.repeatedgroup.add().a = 217 message.repeated_nested_message.add().bb = 218 @@ -118,8 +112,8 @@ def SetAllNonLazyFields(message): message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR) message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR) - message.repeated_string_piece.append('224') - message.repeated_cord.append('225') + message.repeated_string_piece.append(u'224') + message.repeated_cord.append(u'225') # Add a second one of each field. message.repeated_int32.append(301) @@ -135,8 +129,8 @@ def SetAllNonLazyFields(message): message.repeated_float.append(311) message.repeated_double.append(312) message.repeated_bool.append(False) - message.repeated_string.append('315') - message.repeated_bytes.append('316') + message.repeated_string.append(u'315') + message.repeated_bytes.append(b'316') message.repeatedgroup.add().a = 317 message.repeated_nested_message.add().bb = 318 @@ -148,8 +142,8 @@ def SetAllNonLazyFields(message): message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ) message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ) - message.repeated_string_piece.append('324') - message.repeated_cord.append('325') + message.repeated_string_piece.append(u'324') + message.repeated_cord.append(u'325') # # Fields that have defaults. @@ -169,7 +163,7 @@ def SetAllNonLazyFields(message): message.default_double = 412 message.default_bool = False message.default_string = '415' - message.default_bytes = '416' + message.default_bytes = b'416' message.default_nested_enum = unittest_pb2.TestAllTypes.FOO message.default_foreign_enum = unittest_pb2.FOREIGN_FOO @@ -178,6 +172,11 @@ def SetAllNonLazyFields(message): message.default_string_piece = '424' message.default_cord = '425' + message.oneof_uint32 = 601 + message.oneof_nested_message.bb = 602 + message.oneof_string = '603' + message.oneof_bytes = b'604' + def SetAllFields(message): SetAllNonLazyFields(message) @@ -212,8 +211,8 @@ def SetAllExtensions(message): extensions[pb2.optional_float_extension] = 111 extensions[pb2.optional_double_extension] = 112 extensions[pb2.optional_bool_extension] = True - extensions[pb2.optional_string_extension] = '115' - extensions[pb2.optional_bytes_extension] = '116' + extensions[pb2.optional_string_extension] = u'115' + extensions[pb2.optional_bytes_extension] = b'116' extensions[pb2.optionalgroup_extension].a = 117 extensions[pb2.optional_nested_message_extension].bb = 118 @@ -227,8 +226,8 @@ def SetAllExtensions(message): extensions[pb2.optional_foreign_enum_extension] = pb2.FOREIGN_BAZ extensions[pb2.optional_import_enum_extension] = import_pb2.IMPORT_BAZ - extensions[pb2.optional_string_piece_extension] = '124' - extensions[pb2.optional_cord_extension] = '125' + extensions[pb2.optional_string_piece_extension] = u'124' + extensions[pb2.optional_cord_extension] = u'125' # # Repeated fields. @@ -247,8 +246,8 @@ def SetAllExtensions(message): extensions[pb2.repeated_float_extension].append(211) extensions[pb2.repeated_double_extension].append(212) extensions[pb2.repeated_bool_extension].append(True) - extensions[pb2.repeated_string_extension].append('215') - extensions[pb2.repeated_bytes_extension].append('216') + extensions[pb2.repeated_string_extension].append(u'215') + extensions[pb2.repeated_bytes_extension].append(b'216') extensions[pb2.repeatedgroup_extension].add().a = 217 extensions[pb2.repeated_nested_message_extension].add().bb = 218 @@ -260,8 +259,8 @@ def SetAllExtensions(message): extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAR) extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAR) - extensions[pb2.repeated_string_piece_extension].append('224') - extensions[pb2.repeated_cord_extension].append('225') + extensions[pb2.repeated_string_piece_extension].append(u'224') + extensions[pb2.repeated_cord_extension].append(u'225') # Append a second one of each field. extensions[pb2.repeated_int32_extension].append(301) @@ -277,8 +276,8 @@ def SetAllExtensions(message): extensions[pb2.repeated_float_extension].append(311) extensions[pb2.repeated_double_extension].append(312) extensions[pb2.repeated_bool_extension].append(False) - extensions[pb2.repeated_string_extension].append('315') - extensions[pb2.repeated_bytes_extension].append('316') + extensions[pb2.repeated_string_extension].append(u'315') + extensions[pb2.repeated_bytes_extension].append(b'316') extensions[pb2.repeatedgroup_extension].add().a = 317 extensions[pb2.repeated_nested_message_extension].add().bb = 318 @@ -290,8 +289,8 @@ def SetAllExtensions(message): extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAZ) extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAZ) - extensions[pb2.repeated_string_piece_extension].append('324') - extensions[pb2.repeated_cord_extension].append('325') + extensions[pb2.repeated_string_piece_extension].append(u'324') + extensions[pb2.repeated_cord_extension].append(u'325') # # Fields with defaults. @@ -310,16 +309,21 @@ def SetAllExtensions(message): extensions[pb2.default_float_extension] = 411 extensions[pb2.default_double_extension] = 412 extensions[pb2.default_bool_extension] = False - extensions[pb2.default_string_extension] = '415' - extensions[pb2.default_bytes_extension] = '416' + extensions[pb2.default_string_extension] = u'415' + extensions[pb2.default_bytes_extension] = b'416' extensions[pb2.default_nested_enum_extension] = pb2.TestAllTypes.FOO extensions[pb2.default_foreign_enum_extension] = pb2.FOREIGN_FOO extensions[pb2.default_import_enum_extension] = import_pb2.IMPORT_FOO - extensions[pb2.default_string_piece_extension] = '424' + extensions[pb2.default_string_piece_extension] = u'424' extensions[pb2.default_cord_extension] = '425' + extensions[pb2.oneof_uint32_extension] = 601 + extensions[pb2.oneof_nested_message_extension].bb = 602 + extensions[pb2.oneof_string_extension] = u'603' + extensions[pb2.oneof_bytes_extension] = b'604' + def SetAllFieldsAndExtensions(message): """Sets every field and extension in the message to a unique value. @@ -358,7 +362,7 @@ def ExpectAllFieldsAndExtensionsInOrder(serialized): message.my_float = 1.0 expected_strings.append(message.SerializeToString()) message.Clear() - expected = ''.join(expected_strings) + expected = b''.join(expected_strings) if expected != serialized: raise ValueError('Expected %r, found %r' % (expected, serialized)) @@ -413,7 +417,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(112, message.optional_double) test_case.assertEqual(True, message.optional_bool) test_case.assertEqual('115', message.optional_string) - test_case.assertEqual('116', message.optional_bytes) + test_case.assertEqual(b'116', message.optional_bytes) test_case.assertEqual(117, message.optionalgroup.a) test_case.assertEqual(118, message.optional_nested_message.bb) @@ -472,7 +476,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(212, message.repeated_double[0]) test_case.assertEqual(True, message.repeated_bool[0]) test_case.assertEqual('215', message.repeated_string[0]) - test_case.assertEqual('216', message.repeated_bytes[0]) + test_case.assertEqual(b'216', message.repeated_bytes[0]) test_case.assertEqual(217, message.repeatedgroup[0].a) test_case.assertEqual(218, message.repeated_nested_message[0].bb) @@ -501,7 +505,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(312, message.repeated_double[1]) test_case.assertEqual(False, message.repeated_bool[1]) test_case.assertEqual('315', message.repeated_string[1]) - test_case.assertEqual('316', message.repeated_bytes[1]) + test_case.assertEqual(b'316', message.repeated_bytes[1]) test_case.assertEqual(317, message.repeatedgroup[1].a) test_case.assertEqual(318, message.repeated_nested_message[1].bb) @@ -552,7 +556,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(412, message.default_double) test_case.assertEqual(False, message.default_bool) test_case.assertEqual('415', message.default_string) - test_case.assertEqual('416', message.default_bytes) + test_case.assertEqual(b'416', message.default_bytes) test_case.assertEqual(unittest_pb2.TestAllTypes.FOO, message.default_nested_enum) @@ -561,6 +565,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(unittest_import_pb2.IMPORT_FOO, message.default_import_enum) + def GoldenFile(filename): """Finds the given golden file and returns a file object representing it.""" @@ -574,9 +579,15 @@ def GoldenFile(filename): path = os.path.join(path, '..') raise RuntimeError( - 'Could not find golden files. This test must be run from within the ' - 'protobuf source package so that it can read test data files from the ' - 'C++ source tree.') + 'Could not find golden files. This test must be run from within the ' + 'protobuf source package so that it can read test data files from the ' + 'C++ source tree.') + + +def GoldenFileData(filename): + """Finds the given golden file and returns its contents.""" + with GoldenFile(filename) as f: + return f.read() def SetAllPackedFields(message): diff --git a/python/google/protobuf/internal/text_encoding_test.py b/python/google/protobuf/internal/text_encoding_test.py new file mode 100755 index 00000000..ba0e45d6 --- /dev/null +++ b/python/google/protobuf/internal/text_encoding_test.py @@ -0,0 +1,68 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.text_encoding.""" + +from google.apputils import basetest +from google.protobuf import text_encoding + +TEST_VALUES = [ + ("foo\\rbar\\nbaz\\t", + "foo\\rbar\\nbaz\\t", + b"foo\rbar\nbaz\t"), + ("\\'full of \\\"sound\\\" and \\\"fury\\\"\\'", + "\\'full of \\\"sound\\\" and \\\"fury\\\"\\'", + b"'full of \"sound\" and \"fury\"'"), + ("signi\\\\fying\\\\ nothing\\\\", + "signi\\\\fying\\\\ nothing\\\\", + b"signi\\fying\\ nothing\\"), + ("\\010\\t\\n\\013\\014\\r", + "\x08\\t\\n\x0b\x0c\\r", + b"\010\011\012\013\014\015")] + + +class TextEncodingTestCase(basetest.TestCase): + def testCEscape(self): + for escaped, escaped_utf8, unescaped in TEST_VALUES: + self.assertEquals(escaped, + text_encoding.CEscape(unescaped, as_utf8=False)) + self.assertEquals(escaped_utf8, + text_encoding.CEscape(unescaped, as_utf8=True)) + + def testCUnescape(self): + for escaped, escaped_utf8, unescaped in TEST_VALUES: + self.assertEquals(unescaped, text_encoding.CUnescape(escaped)) + self.assertEquals(unescaped, text_encoding.CUnescape(escaped_utf8)) + + +if __name__ == "__main__": + basetest.main() diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 4b1b4f59..a85d4dd8 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -34,49 +34,71 @@ __author__ = 'kenton@google.com (Kenton Varda)' -import difflib import re -import unittest +from google.apputils import basetest from google.protobuf import text_format +from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util from google.protobuf import unittest_pb2 from google.protobuf import unittest_mset_pb2 +class TextFormatTest(basetest.TestCase): -class TextFormatTest(unittest.TestCase): def ReadGolden(self, golden_filename): - f = test_util.GoldenFile(golden_filename) - golden_lines = f.readlines() - f.close() - return golden_lines + with test_util.GoldenFile(golden_filename) as f: + return (f.readlines() if str is bytes else # PY3 + [golden_line.decode('utf-8') for golden_line in f]) def CompareToGoldenFile(self, text, golden_filename): golden_lines = self.ReadGolden(golden_filename) - self.CompareToGoldenLines(text, golden_lines) + self.assertMultiLineEqual(text, ''.join(golden_lines)) def CompareToGoldenText(self, text, golden_text): - self.CompareToGoldenLines(text, golden_text.splitlines(1)) - - def CompareToGoldenLines(self, text, golden_lines): - actual_lines = text.splitlines(1) - self.assertEqual(golden_lines, actual_lines, - "Text doesn't match golden. Diff:\n" + - ''.join(difflib.ndiff(golden_lines, actual_lines))) + self.assertMultiLineEqual(text, golden_text) def testPrintAllFields(self): message = unittest_pb2.TestAllTypes() test_util.SetAllFields(message) self.CompareToGoldenFile( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'text_format_unittest_data.txt') + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_data_oneof_implemented.txt') + + def testPrintInIndexOrder(self): + message = unittest_pb2.TestFieldOrderings() + message.my_string = '115' + message.my_int = 101 + message.my_float = 111 + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString( + message, use_index_order=True)), + 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString( + message)), 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n') def testPrintAllExtensions(self): message = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(message) self.CompareToGoldenFile( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'text_format_unittest_extensions_data.txt') + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_extensions_data.txt') + + def testPrintAllFieldsPointy(self): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros( + text_format.MessageToString(message, pointy_brackets=True)), + 'text_format_unittest_data_pointy_oneof_implemented.txt') + + def testPrintAllExtensionsPointy(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros(text_format.MessageToString( + message, pointy_brackets=True)), + 'text_format_unittest_extensions_data_pointy.txt') def testPrintMessageSet(self): message = unittest_mset_pb2.TestMessageSetContainer() @@ -84,37 +106,16 @@ class TextFormatTest(unittest.TestCase): ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension message.message_set.Extensions[ext1].i = 23 message.message_set.Extensions[ext2].str = 'foo' - self.CompareToGoldenText(text_format.MessageToString(message), - 'message_set {\n' - ' [protobuf_unittest.TestMessageSetExtension1] {\n' - ' i: 23\n' - ' }\n' - ' [protobuf_unittest.TestMessageSetExtension2] {\n' - ' str: \"foo\"\n' - ' }\n' - '}\n') - - def testPrintBadEnumValue(self): - message = unittest_pb2.TestAllTypes() - message.optional_nested_enum = 100 - message.optional_foreign_enum = 101 - message.optional_import_enum = 102 - self.CompareToGoldenText( - text_format.MessageToString(message), - 'optional_nested_enum: 100\n' - 'optional_foreign_enum: 101\n' - 'optional_import_enum: 102\n') - - def testPrintBadEnumValueExtensions(self): - message = unittest_pb2.TestAllExtensions() - message.Extensions[unittest_pb2.optional_nested_enum_extension] = 100 - message.Extensions[unittest_pb2.optional_foreign_enum_extension] = 101 - message.Extensions[unittest_pb2.optional_import_enum_extension] = 102 self.CompareToGoldenText( text_format.MessageToString(message), - '[protobuf_unittest.optional_nested_enum_extension]: 100\n' - '[protobuf_unittest.optional_foreign_enum_extension]: 101\n' - '[protobuf_unittest.optional_import_enum_extension]: 102\n') + 'message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') def testPrintExotic(self): message = unittest_pb2.TestAllTypes() @@ -126,20 +127,29 @@ class TextFormatTest(unittest.TestCase): message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"') message.repeated_string.append(u'\u00fc\ua71f') self.CompareToGoldenText( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'repeated_int64: -9223372036854775808\n' - 'repeated_uint64: 18446744073709551615\n' - 'repeated_double: 123.456\n' - 'repeated_double: 1.23e+22\n' - 'repeated_double: 1.23e-18\n' - 'repeated_string: ' - '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n' - 'repeated_string: "\\303\\274\\352\\234\\237"\n') + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'repeated_int64: -9223372036854775808\n' + 'repeated_uint64: 18446744073709551615\n' + 'repeated_double: 123.456\n' + 'repeated_double: 1.23e+22\n' + 'repeated_double: 1.23e-18\n' + 'repeated_string:' + ' "\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n' + 'repeated_string: "\\303\\274\\352\\234\\237"\n') + + def testPrintExoticUnicodeSubclass(self): + class UnicodeSub(unicode): + pass + message = unittest_pb2.TestAllTypes() + message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f')) + self.CompareToGoldenText( + text_format.MessageToString(message), + 'repeated_string: "\\303\\274\\352\\234\\237"\n') def testPrintNestedMessageAsOneLine(self): message = unittest_pb2.TestAllTypes() msg = message.repeated_nested_message.add() - msg.bb = 42; + msg.bb = 42 self.CompareToGoldenText( text_format.MessageToString(message, as_one_line=True), 'repeated_nested_message { bb: 42 }') @@ -190,16 +200,16 @@ class TextFormatTest(unittest.TestCase): message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"') message.repeated_string.append(u'\u00fc\ua71f') self.CompareToGoldenText( - self.RemoveRedundantZeros( - text_format.MessageToString(message, as_one_line=True)), - 'repeated_int64: -9223372036854775808' - ' repeated_uint64: 18446744073709551615' - ' repeated_double: 123.456' - ' repeated_double: 1.23e+22' - ' repeated_double: 1.23e-18' - ' repeated_string: ' - '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""' - ' repeated_string: "\\303\\274\\352\\234\\237"') + self.RemoveRedundantZeros( + text_format.MessageToString(message, as_one_line=True)), + 'repeated_int64: -9223372036854775808' + ' repeated_uint64: 18446744073709551615' + ' repeated_double: 123.456' + ' repeated_double: 1.23e+22' + ' repeated_double: 1.23e-18' + ' repeated_string: ' + '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""' + ' repeated_string: "\\303\\274\\352\\234\\237"') def testRoundTripExoticAsOneLine(self): message = unittest_pb2.TestAllTypes() @@ -215,24 +225,60 @@ class TextFormatTest(unittest.TestCase): wire_text = text_format.MessageToString( message, as_one_line=True, as_utf8=False) parsed_message = unittest_pb2.TestAllTypes() - text_format.Merge(wire_text, parsed_message) + r = text_format.Parse(wire_text, parsed_message) + self.assertIs(r, parsed_message) self.assertEquals(message, parsed_message) # Test as_utf8 = True. wire_text = text_format.MessageToString( message, as_one_line=True, as_utf8=True) parsed_message = unittest_pb2.TestAllTypes() - text_format.Merge(wire_text, parsed_message) - self.assertEquals(message, parsed_message) + r = text_format.Parse(wire_text, parsed_message) + self.assertIs(r, parsed_message) + self.assertEquals(message, parsed_message, + '\n%s != %s' % (message, parsed_message)) def testPrintRawUtf8String(self): message = unittest_pb2.TestAllTypes() message.repeated_string.append(u'\u00fc\ua71f') - text = text_format.MessageToString(message, as_utf8 = True) + text = text_format.MessageToString(message, as_utf8=True) self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n') parsed_message = unittest_pb2.TestAllTypes() - text_format.Merge(text, parsed_message) - self.assertEquals(message, parsed_message) + text_format.Parse(text, parsed_message) + self.assertEquals(message, parsed_message, + '\n%s != %s' % (message, parsed_message)) + + def testPrintFloatFormat(self): + # Check that float_format argument is passed to sub-message formatting. + message = unittest_pb2.NestedTestAllTypes() + # We use 1.25 as it is a round number in binary. The proto 32-bit float + # will not gain additional imprecise digits as a 64-bit Python float and + # show up in its str. 32-bit 1.2 is noisy when extended to 64-bit: + # >>> struct.unpack('f', struct.pack('f', 1.2))[0] + # 1.2000000476837158 + # >>> struct.unpack('f', struct.pack('f', 1.25))[0] + # 1.25 + message.payload.optional_float = 1.25 + # Check rounding at 15 significant digits + message.payload.optional_double = -.000003456789012345678 + # Check no decimal point. + message.payload.repeated_float.append(-5642) + # Check no trailing zeros. + message.payload.repeated_double.append(.000078900) + formatted_fields = ['optional_float: 1.25', + 'optional_double: -3.45678901234568e-6', + 'repeated_float: -5642', + 'repeated_double: 7.89e-5'] + text_message = text_format.MessageToString(message, float_format='.15g') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_message), + 'payload {{\n {}\n {}\n {}\n {}\n}}\n'.format(*formatted_fields)) + # as_one_line=True is a separate code branch where float_format is passed. + text_message = text_format.MessageToString(message, as_one_line=True, + float_format='.15g') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_message), + 'payload {{ {} {} {} {} }}'.format(*formatted_fields)) def testMessageToString(self): message = unittest_pb2.ForeignMessage() @@ -249,49 +295,50 @@ class TextFormatTest(unittest.TestCase): text = re.compile('\.0$', re.MULTILINE).sub('', text) return text - def testMergeGolden(self): + def testParseGolden(self): golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt')) parsed_message = unittest_pb2.TestAllTypes() - text_format.Merge(golden_text, parsed_message) + r = text_format.Parse(golden_text, parsed_message) + self.assertIs(r, parsed_message) message = unittest_pb2.TestAllTypes() test_util.SetAllFields(message) self.assertEquals(message, parsed_message) - def testMergeGoldenExtensions(self): + def testParseGoldenExtensions(self): golden_text = '\n'.join(self.ReadGolden( 'text_format_unittest_extensions_data.txt')) parsed_message = unittest_pb2.TestAllExtensions() - text_format.Merge(golden_text, parsed_message) + text_format.Parse(golden_text, parsed_message) message = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(message) self.assertEquals(message, parsed_message) - def testMergeAllFields(self): + def testParseAllFields(self): message = unittest_pb2.TestAllTypes() test_util.SetAllFields(message) ascii_text = text_format.MessageToString(message) parsed_message = unittest_pb2.TestAllTypes() - text_format.Merge(ascii_text, parsed_message) + text_format.Parse(ascii_text, parsed_message) self.assertEqual(message, parsed_message) test_util.ExpectAllFieldsSet(self, message) - def testMergeAllExtensions(self): + def testParseAllExtensions(self): message = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(message) ascii_text = text_format.MessageToString(message) parsed_message = unittest_pb2.TestAllExtensions() - text_format.Merge(ascii_text, parsed_message) + text_format.Parse(ascii_text, parsed_message) self.assertEqual(message, parsed_message) - def testMergeMessageSet(self): + def testParseMessageSet(self): message = unittest_pb2.TestAllTypes() text = ('repeated_uint64: 1\n' 'repeated_uint64: 2\n') - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertEqual(1, message.repeated_uint64[0]) self.assertEqual(2, message.repeated_uint64[1]) @@ -304,13 +351,13 @@ class TextFormatTest(unittest.TestCase): ' str: \"foo\"\n' ' }\n' '}\n') - text_format.Merge(text, message) + text_format.Parse(text, message) ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension self.assertEquals(23, message.message_set.Extensions[ext1].i) self.assertEquals('foo', message.message_set.Extensions[ext2].str) - def testMergeExotic(self): + def testParseExotic(self): message = unittest_pb2.TestAllTypes() text = ('repeated_int64: -9223372036854775808\n' 'repeated_uint64: 18446744073709551615\n' @@ -323,7 +370,7 @@ class TextFormatTest(unittest.TestCase): 'repeated_string: "\\303\\274\\352\\234\\237"\n' 'repeated_string: "\\xc3\\xbc"\n' 'repeated_string: "\xc3\xbc"\n') - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertEqual(-9223372036854775808, message.repeated_int64[0]) self.assertEqual(18446744073709551615, message.repeated_uint64[0]) @@ -336,100 +383,115 @@ class TextFormatTest(unittest.TestCase): self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2]) self.assertEqual(u'\u00fc', message.repeated_string[3]) - def testMergeEmptyText(self): + def testParseTrailingCommas(self): + message = unittest_pb2.TestAllTypes() + text = ('repeated_int64: 100;\n' + 'repeated_int64: 200;\n' + 'repeated_int64: 300,\n' + 'repeated_string: "one",\n' + 'repeated_string: "two";\n') + text_format.Parse(text, message) + + self.assertEqual(100, message.repeated_int64[0]) + self.assertEqual(200, message.repeated_int64[1]) + self.assertEqual(300, message.repeated_int64[2]) + self.assertEqual(u'one', message.repeated_string[0]) + self.assertEqual(u'two', message.repeated_string[1]) + + def testParseEmptyText(self): message = unittest_pb2.TestAllTypes() text = '' - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertEquals(unittest_pb2.TestAllTypes(), message) - def testMergeInvalidUtf8(self): + def testParseInvalidUtf8(self): message = unittest_pb2.TestAllTypes() text = 'repeated_string: "\\xc3\\xc3"' - self.assertRaises(text_format.ParseError, text_format.Merge, text, message) + self.assertRaises(text_format.ParseError, text_format.Parse, text, message) - def testMergeSingleWord(self): + def testParseSingleWord(self): message = unittest_pb2.TestAllTypes() text = 'foo' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named ' '"foo".'), - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeUnknownField(self): + def testParseUnknownField(self): message = unittest_pb2.TestAllTypes() text = 'unknown_field: 8\n' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named ' '"unknown_field".'), - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeBadExtension(self): + def testParseBadExtension(self): message = unittest_pb2.TestAllExtensions() text = '[unknown_extension]: 8\n' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, '1:2 : Extension "unknown_extension" not registered.', - text_format.Merge, text, message) + text_format.Parse, text, message) message = unittest_pb2.TestAllTypes() - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have ' 'extensions.'), - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeGroupNotClosed(self): + def testParseGroupNotClosed(self): message = unittest_pb2.TestAllTypes() text = 'RepeatedGroup: <' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, '1:16 : Expected ">".', - text_format.Merge, text, message) + text_format.Parse, text, message) text = 'RepeatedGroup: {' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, '1:16 : Expected "}".', - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeEmptyGroup(self): + def testParseEmptyGroup(self): message = unittest_pb2.TestAllTypes() text = 'OptionalGroup: {}' - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertTrue(message.HasField('optionalgroup')) message.Clear() message = unittest_pb2.TestAllTypes() text = 'OptionalGroup: <>' - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertTrue(message.HasField('optionalgroup')) - def testMergeBadEnumValue(self): + def testParseBadEnumValue(self): message = unittest_pb2.TestAllTypes() text = 'optional_nested_enum: BARR' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' 'has no value named BARR.'), - text_format.Merge, text, message) + text_format.Parse, text, message) message = unittest_pb2.TestAllTypes() text = 'optional_nested_enum: 100' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' 'has no value with number 100.'), - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeBadIntValue(self): + def testParseBadIntValue(self): message = unittest_pb2.TestAllTypes() text = 'optional_int32: bork' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:17 : Couldn\'t parse integer: bork'), - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeStringFieldUnescape(self): + def testParseStringFieldUnescape(self): message = unittest_pb2.TestAllTypes() text = r'''repeated_string: "\xf\x62" repeated_string: "\\xf\\x62" @@ -437,7 +499,7 @@ class TextFormatTest(unittest.TestCase): repeated_string: "\\\\xf\\\\x62" repeated_string: "\\\\\xf\\\\\x62" repeated_string: "\x5cx20"''' - text_format.Merge(text, message) + text_format.Parse(text, message) SLASH = '\\' self.assertEqual('\x0fb', message.repeated_string[0]) @@ -449,27 +511,84 @@ class TextFormatTest(unittest.TestCase): message.repeated_string[4]) self.assertEqual(SLASH + 'x20', message.repeated_string[5]) - def assertRaisesWithMessage(self, e_class, e, func, *args, **kwargs): - """Same as assertRaises, but also compares the exception message.""" - if hasattr(e_class, '__name__'): - exc_name = e_class.__name__ - else: - exc_name = str(e_class) + def testMergeRepeatedScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_int32: 42 ' + 'optional_int32: 67') + r = text_format.Merge(text, message) + self.assertIs(r, message) + self.assertEqual(67, message.optional_int32) + + def testParseRepeatedScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_int32: 42 ' + 'optional_int32: 67') + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not ' + 'have multiple "optional_int32" fields.'), + text_format.Parse, text, message) + + def testMergeRepeatedNestedMessageScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_nested_message { bb: 1 } ' + 'optional_nested_message { bb: 2 }') + r = text_format.Merge(text, message) + self.assertTrue(r is message) + self.assertEqual(2, message.optional_nested_message.bb) + + def testParseRepeatedNestedMessageScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_nested_message { bb: 1 } ' + 'optional_nested_message { bb: 2 }') + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" ' + 'should not have multiple "bb" fields.'), + text_format.Parse, text, message) + + def testMergeRepeatedExtensionScalars(self): + message = unittest_pb2.TestAllExtensions() + text = ('[protobuf_unittest.optional_int32_extension]: 42 ' + '[protobuf_unittest.optional_int32_extension]: 67') + text_format.Merge(text, message) + self.assertEqual( + 67, + message.Extensions[unittest_pb2.optional_int32_extension]) + + def testParseRepeatedExtensionScalars(self): + message = unittest_pb2.TestAllExtensions() + text = ('[protobuf_unittest.optional_int32_extension]: 42 ' + '[protobuf_unittest.optional_int32_extension]: 67') + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:96 : Message type "protobuf_unittest.TestAllExtensions" ' + 'should not have multiple ' + '"protobuf_unittest.optional_int32_extension" extensions.'), + text_format.Parse, text, message) + + def testParseLinesGolden(self): + opened = self.ReadGolden('text_format_unittest_data.txt') + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.ParseLines(opened, parsed_message) + self.assertIs(r, parsed_message) + + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEquals(message, parsed_message) + + def testMergeLinesGolden(self): + opened = self.ReadGolden('text_format_unittest_data.txt') + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.MergeLines(opened, parsed_message) + self.assertIs(r, parsed_message) - try: - func(*args, **kwargs) - except e_class as expr: - if str(expr) != e: - msg = '%s raised, but with wrong message: "%s" instead of "%s"' - raise self.failureException(msg % (exc_name, - str(expr).encode('string_escape'), - e.encode('string_escape'))) - return - else: - raise self.failureException('%s not raised' % exc_name) + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEqual(message, parsed_message) -class TokenizerTest(unittest.TestCase): +class TokenizerTest(basetest.TestCase): def testSimpleTokenCases(self): text = ('identifier1:"string1"\n \n\n' @@ -478,8 +597,8 @@ class TokenizerTest(unittest.TestCase): 'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n' 'ID9: 22 ID10: -111111111111111111 ID11: -22\n' 'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f ' - 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ' ) - tokenizer = text_format._Tokenizer(text) + 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ') + tokenizer = text_format._Tokenizer(text.splitlines()) methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), ':', (tokenizer.ConsumeString, 'string1'), @@ -565,7 +684,7 @@ class TokenizerTest(unittest.TestCase): int64_max = (1 << 63) - 1 uint32_max = (1 << 32) - 1 text = '-1 %d %d' % (uint32_max + 1, int64_max + 1) - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32) self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint64) self.assertEqual(-1, tokenizer.ConsumeInt32()) @@ -579,7 +698,7 @@ class TokenizerTest(unittest.TestCase): self.assertTrue(tokenizer.AtEnd()) text = '-0 -0 0 0' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertEqual(0, tokenizer.ConsumeUint32()) self.assertEqual(0, tokenizer.ConsumeUint64()) self.assertEqual(0, tokenizer.ConsumeUint32()) @@ -588,30 +707,30 @@ class TokenizerTest(unittest.TestCase): def testConsumeByteString(self): text = '"string1\'' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = 'string1"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\xt"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\x"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) def testConsumeBool(self): text = 'not-a-bool' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool) if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index 2b3cd4de..8e1b3cc3 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#PY25 compatible for GAE. +# +# Copyright 2008 Google Inc. All Rights Reserved. + """Provides type checking routines. This module defines type checking utilities in the forms of dictionaries: @@ -45,6 +49,9 @@ TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization __author__ = 'robinson@google.com (Will Robinson)' +import sys ##PY25 +if sys.version < '2.6': bytes = str ##PY25 +from google.protobuf.internal import api_implementation from google.protobuf.internal import decoder from google.protobuf.internal import encoder from google.protobuf.internal import wire_format @@ -53,21 +60,22 @@ from google.protobuf import descriptor _FieldDescriptor = descriptor.FieldDescriptor -def GetTypeChecker(cpp_type, field_type): +def GetTypeChecker(field): """Returns a type checker for a message field of the specified types. Args: - cpp_type: C++ type of the field (see descriptor.py). - field_type: Protocol message field type (see descriptor.py). + field: FieldDescriptor object for this field. Returns: An instance of TypeChecker which can be used to verify the types of values assigned to a field of the specified type. """ - if (cpp_type == _FieldDescriptor.CPPTYPE_STRING and - field_type == _FieldDescriptor.TYPE_STRING): + if (field.cpp_type == _FieldDescriptor.CPPTYPE_STRING and + field.type == _FieldDescriptor.TYPE_STRING): return UnicodeValueChecker() - return _VALUE_CHECKERS[cpp_type] + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + return EnumValueChecker(field.enum_type) + return _VALUE_CHECKERS[field.cpp_type] # None of the typecheckers below make any attempt to guard against people @@ -85,10 +93,15 @@ class TypeChecker(object): self._acceptable_types = acceptable_types def CheckValue(self, proposed_value): + """Type check the provided value and return it. + + The returned value might have been normalized to another type. + """ if not isinstance(proposed_value, self._acceptable_types): message = ('%.1024r has type %s, but expected one of: %s' % (proposed_value, type(proposed_value), self._acceptable_types)) raise TypeError(message) + return proposed_value # IntValueChecker and its subclasses perform integer type-checks @@ -104,28 +117,54 @@ class IntValueChecker(object): raise TypeError(message) if not self._MIN <= proposed_value <= self._MAX: raise ValueError('Value out of range: %d' % proposed_value) + # We force 32-bit values to int and 64-bit values to long to make + # alternate implementations where the distinction is more significant + # (e.g. the C++ implementation) simpler. + proposed_value = self._TYPE(proposed_value) + return proposed_value + + +class EnumValueChecker(object): + + """Checker used for enum fields. Performs type-check and range check.""" + + def __init__(self, enum_type): + self._enum_type = enum_type + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, (int, long)): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), (int, long))) + raise TypeError(message) + if proposed_value not in self._enum_type.values_by_number: + raise ValueError('Unknown enum value: %d' % proposed_value) + return proposed_value class UnicodeValueChecker(object): - """Checker used for string fields.""" + """Checker used for string fields. + + Always returns a unicode value, even if the input is of type str. + """ def CheckValue(self, proposed_value): - if not isinstance(proposed_value, (str, unicode)): + if not isinstance(proposed_value, (bytes, unicode)): message = ('%.1024r has type %s, but expected one of: %s' % - (proposed_value, type(proposed_value), (str, unicode))) + (proposed_value, type(proposed_value), (bytes, unicode))) raise TypeError(message) - # If the value is of type 'str' make sure that it is in 7-bit ASCII + # If the value is of type 'bytes' make sure that it is in 7-bit ASCII # encoding. - if isinstance(proposed_value, str): + if isinstance(proposed_value, bytes): try: - unicode(proposed_value, 'ascii') + proposed_value = proposed_value.decode('ascii') except UnicodeDecodeError: - raise ValueError('%.1024r has type str, but isn\'t in 7-bit ASCII ' + raise ValueError('%.1024r has type bytes, but isn\'t in 7-bit ASCII ' 'encoding. Non-ASCII strings must be converted to ' 'unicode objects before being added.' % (proposed_value)) + return proposed_value class Int32ValueChecker(IntValueChecker): @@ -133,21 +172,25 @@ class Int32ValueChecker(IntValueChecker): # efficient. _MIN = -2147483648 _MAX = 2147483647 + _TYPE = int class Uint32ValueChecker(IntValueChecker): _MIN = 0 _MAX = (1 << 32) - 1 + _TYPE = int class Int64ValueChecker(IntValueChecker): _MIN = -(1 << 63) _MAX = (1 << 63) - 1 + _TYPE = long class Uint64ValueChecker(IntValueChecker): _MIN = 0 _MAX = (1 << 64) - 1 + _TYPE = long # Type-checkers for all scalar CPPTYPEs. @@ -161,8 +204,7 @@ _VALUE_CHECKERS = { _FieldDescriptor.CPPTYPE_FLOAT: TypeChecker( float, int, long), _FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int), - _FieldDescriptor.CPPTYPE_ENUM: Int32ValueChecker(), - _FieldDescriptor.CPPTYPE_STRING: TypeChecker(str), + _FieldDescriptor.CPPTYPE_STRING: TypeChecker(bytes), } diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index 84984b40..8f3354c9 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -35,15 +35,16 @@ __author__ = 'bohdank@google.com (Bohdan Koval)' -import unittest +from google.apputils import basetest from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 from google.protobuf.internal import encoder +from google.protobuf.internal import missing_enum_values_pb2 from google.protobuf.internal import test_util from google.protobuf.internal import type_checkers -class UnknownFieldsTest(unittest.TestCase): +class UnknownFieldsTest(basetest.TestCase): def setUp(self): self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR @@ -58,12 +59,20 @@ class UnknownFieldsTest(unittest.TestCase): field_descriptor = self.descriptor.fields_by_name[name] wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] field_tag = encoder.TagBytes(field_descriptor.number, wire_type) + result_dict = {} for tag_bytes, value in self.unknown_fields: if tag_bytes == field_tag: decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes] - result_dict = {} decoder(value, 0, len(value), self.all_fields, result_dict) - return result_dict[field_descriptor] + return result_dict[field_descriptor] + + def testEnum(self): + value = self.GetField('optional_nested_enum') + self.assertEqual(self.all_fields.optional_nested_enum, value) + + def testRepeatedEnum(self): + value = self.GetField('repeated_nested_enum') + self.assertEqual(self.all_fields.repeated_nested_enum, value) def testVarint(self): value = self.GetField('optional_int32') @@ -166,5 +175,57 @@ class UnknownFieldsTest(unittest.TestCase): self.assertNotEqual(self.empty_message, message) +class UnknownFieldsTest(basetest.TestCase): + + def setUp(self): + self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR + + self.message = missing_enum_values_pb2.TestEnumValues() + self.message.optional_nested_enum = ( + missing_enum_values_pb2.TestEnumValues.ZERO) + self.message.repeated_nested_enum.extend([ + missing_enum_values_pb2.TestEnumValues.ZERO, + missing_enum_values_pb2.TestEnumValues.ONE, + ]) + self.message.packed_nested_enum.extend([ + missing_enum_values_pb2.TestEnumValues.ZERO, + missing_enum_values_pb2.TestEnumValues.ONE, + ]) + self.message_data = self.message.SerializeToString() + self.missing_message = missing_enum_values_pb2.TestMissingEnumValues() + self.missing_message.ParseFromString(self.message_data) + self.unknown_fields = self.missing_message._unknown_fields + + def GetField(self, name): + field_descriptor = self.descriptor.fields_by_name[name] + wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] + field_tag = encoder.TagBytes(field_descriptor.number, wire_type) + result_dict = {} + for tag_bytes, value in self.unknown_fields: + if tag_bytes == field_tag: + decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[ + tag_bytes] + decoder(value, 0, len(value), self.message, result_dict) + return result_dict[field_descriptor] + + def testUnknownEnumValue(self): + self.assertFalse(self.missing_message.HasField('optional_nested_enum')) + value = self.GetField('optional_nested_enum') + self.assertEqual(self.message.optional_nested_enum, value) + + def testUnknownRepeatedEnumValue(self): + value = self.GetField('repeated_nested_enum') + self.assertEqual(self.message.repeated_nested_enum, value) + + def testUnknownPackedEnumValue(self): + value = self.GetField('packed_nested_enum') + self.assertEqual(self.message.packed_nested_enum, value) + + def testRoundTrip(self): + new_message = missing_enum_values_pb2.TestEnumValues() + new_message.ParseFromString(self.missing_message.SerializeToString()) + self.assertEqual(self.message, new_message) + + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py index 76007786..9362c72d 100755 --- a/python/google/protobuf/internal/wire_format_test.py +++ b/python/google/protobuf/internal/wire_format_test.py @@ -34,12 +34,12 @@ __author__ = 'robinson@google.com (Will Robinson)' -import unittest +from google.apputils import basetest from google.protobuf import message from google.protobuf.internal import wire_format -class WireFormatTest(unittest.TestCase): +class WireFormatTest(basetest.TestCase): def testPackTag(self): field_number = 0xabc @@ -195,7 +195,7 @@ class WireFormatTest(unittest.TestCase): # Test UTF-8 string byte size calculation. # 1 byte for tag, 1 byte for length, 8 bytes for content. self.assertEqual(10, wire_format.StringByteSize( - 5, unicode('\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82', 'utf-8'))) + 5, b'\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82'.decode('utf-8'))) class MockMessage(object): def __init__(self, byte_size): @@ -250,4 +250,4 @@ class WireFormatTest(unittest.TestCase): if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py index 6ec2f8be..37b0af14 100755 --- a/python/google/protobuf/message.py +++ b/python/google/protobuf/message.py @@ -177,7 +177,11 @@ class Message(object): raise NotImplementedError def ParseFromString(self, serialized): - """Like MergeFromString(), except we clear the object first.""" + """Parse serialized protocol buffer data into this message. + + Like MergeFromString(), except we clear the object first and + do not return the value that MergeFromString returns. + """ self.Clear() self.MergeFromString(serialized) diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py index 36e2fef0..9004ffd9 100644 --- a/python/google/protobuf/message_factory.py +++ b/python/google/protobuf/message_factory.py @@ -28,10 +28,22 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -"""Provides a factory class for generating dynamic messages.""" +#PY25 compatible for GAE. +# +# Copyright 2012 Google Inc. All Rights Reserved. + +"""Provides a factory class for generating dynamic messages. + +The easiest way to use this class is if you have access to the FileDescriptor +protos containing the messages you want to create you can just do the following: + +message_classes = message_factory.GetMessages(iterable_of_file_descriptors) +my_proto_instance = message_classes['some.proto.package.MessageName']() +""" __author__ = 'matthewtoia@google.com (Matt Toia)' +import sys ##PY25 from google.protobuf import descriptor_database from google.protobuf import descriptor_pool from google.protobuf import message @@ -41,8 +53,12 @@ from google.protobuf import reflection class MessageFactory(object): """Factory for creating Proto2 messages from descriptors in a pool.""" - def __init__(self): + def __init__(self, pool=None): """Initializes a new factory.""" + self.pool = (pool or descriptor_pool.DescriptorPool( + descriptor_database.DescriptorDatabase())) + + # local cache of all classes built from protobuf descriptors self._classes = {} def GetPrototype(self, descriptor): @@ -57,21 +73,69 @@ class MessageFactory(object): Returns: A class describing the passed in descriptor. """ - if descriptor.full_name not in self._classes: + descriptor_name = descriptor.name + if sys.version_info[0] < 3: ##PY25 +##!PY25 if str is bytes: # PY2 + descriptor_name = descriptor.name.encode('ascii', 'ignore') result_class = reflection.GeneratedProtocolMessageType( - descriptor.name.encode('ascii', 'ignore'), + descriptor_name, (message.Message,), - {'DESCRIPTOR': descriptor}) + {'DESCRIPTOR': descriptor, '__module__': None}) + # If module not set, it wrongly points to the reflection.py module. self._classes[descriptor.full_name] = result_class for field in descriptor.fields: if field.message_type: self.GetPrototype(field.message_type) + for extension in result_class.DESCRIPTOR.extensions: + if extension.containing_type.full_name not in self._classes: + self.GetPrototype(extension.containing_type) + extended_class = self._classes[extension.containing_type.full_name] + extended_class.RegisterExtension(extension) return self._classes[descriptor.full_name] + def GetMessages(self, files): + """Gets all the messages from a specified file. + + This will find and resolve dependencies, failing if the descriptor + pool cannot satisfy them. + + Args: + files: The file names to extract messages from. + + Returns: + A dictionary mapping proto names to the message classes. This will include + any dependent messages as well as any messages defined in the same file as + a specified message. + """ + result = {} + for file_name in files: + file_desc = self.pool.FindFileByName(file_name) + for name, msg in file_desc.message_types_by_name.iteritems(): + if file_desc.package: + full_name = '.'.join([file_desc.package, name]) + else: + full_name = msg.name + result[full_name] = self.GetPrototype( + self.pool.FindMessageTypeByName(full_name)) + + # While the extension FieldDescriptors are created by the descriptor pool, + # the python classes created in the factory need them to be registered + # explicitly, which is done below. + # + # The call to RegisterExtension will specifically check if the + # extension was already registered on the object and either + # ignore the registration if the original was the same, or raise + # an error if they were different. + + for name, extension in file_desc.extensions_by_name.iteritems(): + if extension.containing_type.full_name not in self._classes: + self.GetPrototype(extension.containing_type) + extended_class = self._classes[extension.containing_type.full_name] + extended_class.RegisterExtension(extension) + return result + -_DB = descriptor_database.DescriptorDatabase() -_POOL = descriptor_pool.DescriptorPool(_DB) _FACTORY = MessageFactory() @@ -82,32 +146,10 @@ def GetMessages(file_protos): file_protos: A sequence of file protos to build messages out of. Returns: - A dictionary containing all the message types in the files mapping the - fully qualified name to a Message subclass for the descriptor. + A dictionary mapping proto names to the message classes. This will include + any dependent messages as well as any messages defined in the same file as + a specified message. """ - - result = {} for file_proto in file_protos: - _DB.Add(file_proto) - for file_proto in file_protos: - for desc in _GetAllDescriptors(file_proto.message_type, file_proto.package): - result[desc.full_name] = _FACTORY.GetPrototype(desc) - return result - - -def _GetAllDescriptors(desc_protos, package): - """Gets all levels of nested message types as a flattened list of descriptors. - - Args: - desc_protos: The descriptor protos to process. - package: The package where the protos are defined. - - Yields: - Each message descriptor for each nested type. - """ - - for desc_proto in desc_protos: - name = '.'.join((package, desc_proto.name)) - yield _POOL.FindMessageTypeByName(name) - for nested_desc in _GetAllDescriptors(desc_proto.nested_type, name): - yield nested_desc + _FACTORY.pool.Add(file_proto) + return _FACTORY.GetMessages([file_proto.name for file_proto in file_protos]) diff --git a/python/google/protobuf/pyext/README b/python/google/protobuf/pyext/README new file mode 100644 index 00000000..6d61cb45 --- /dev/null +++ b/python/google/protobuf/pyext/README @@ -0,0 +1,6 @@ +This is the 'v2' C++ implementation for python proto2. + +It is active when: + +PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp +PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2 diff --git a/python/google/protobuf/pyext/cpp_message.py b/python/google/protobuf/pyext/cpp_message.py new file mode 100644 index 00000000..ba87f8ea --- /dev/null +++ b/python/google/protobuf/pyext/cpp_message.py @@ -0,0 +1,61 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Protocol message implementation hooks for C++ implementation. + +Contains helper functions used to create protocol message classes from +Descriptor objects at runtime backed by the protocol buffer C++ API. +""" + +__author__ = 'tibell@google.com (Johan Tibell)' + +from google.protobuf.pyext import _message +from google.protobuf import message + + +def NewMessage(bases, message_descriptor, dictionary): + """Creates a new protocol message *class*.""" + new_bases = [] + for base in bases: + if base is message.Message: + # _message.Message must come before message.Message as it + # overrides methods in that class. + new_bases.append(_message.Message) + new_bases.append(base) + return tuple(new_bases) + + +def InitMessage(message_descriptor, cls): + """Constructs a new message instance (called before instance's __init__).""" + + def SubInit(self, **kwargs): + super(cls, self).__init__(message_descriptor, **kwargs) + cls.__init__ = SubInit + cls.AddDescriptors(message_descriptor) diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc new file mode 100644 index 00000000..cbf42c00 --- /dev/null +++ b/python/google/protobuf/pyext/descriptor.cc @@ -0,0 +1,357 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: petar@google.com (Petar Petrov) + +#include +#include + +#include +#include +#include + +#define C(str) const_cast(str) + +#if PY_MAJOR_VERSION >= 3 + #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #define PyInt_FromLong PyLong_FromLong + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #else + #define PyString_AsString(ob) \ + (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob)) + #endif +#endif + +namespace google { +namespace protobuf { +namespace python { + + +#ifndef PyVarObject_HEAD_INIT +#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size, +#endif +#ifndef Py_TYPE +#define Py_TYPE(ob) (((PyObject*)(ob))->ob_type) +#endif + + +static google::protobuf::DescriptorPool* g_descriptor_pool = NULL; + +namespace cfield_descriptor { + +static void Dealloc(CFieldDescriptor* self) { + Py_CLEAR(self->descriptor_field); + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + +static PyObject* GetFullName(CFieldDescriptor* self, void *closure) { + return PyString_FromStringAndSize( + self->descriptor->full_name().c_str(), + self->descriptor->full_name().size()); +} + +static PyObject* GetName(CFieldDescriptor *self, void *closure) { + return PyString_FromStringAndSize( + self->descriptor->name().c_str(), + self->descriptor->name().size()); +} + +static PyObject* GetCppType(CFieldDescriptor *self, void *closure) { + return PyInt_FromLong(self->descriptor->cpp_type()); +} + +static PyObject* GetLabel(CFieldDescriptor *self, void *closure) { + return PyInt_FromLong(self->descriptor->label()); +} + +static PyObject* GetID(CFieldDescriptor *self, void *closure) { + return PyLong_FromVoidPtr(self); +} + +static PyGetSetDef Getters[] = { + { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL}, + { C("name"), (getter)GetName, NULL, "last name", NULL}, + { C("cpp_type"), (getter)GetCppType, NULL, "C++ Type", NULL}, + { C("label"), (getter)GetLabel, NULL, "Label", NULL}, + { C("id"), (getter)GetID, NULL, "ID", NULL}, + {NULL} +}; + +} // namespace cfield_descriptor + +PyTypeObject CFieldDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + C("google.protobuf.internal." + "_net_proto2___python." + "CFieldDescriptor"), // tp_name + sizeof(CFieldDescriptor), // tp_basicsize + 0, // tp_itemsize + (destructor)cfield_descriptor::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + C("A Field Descriptor"), // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + 0, // tp_methods + 0, // tp_members + cfield_descriptor::Getters, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + PyType_GenericAlloc, // tp_alloc + PyType_GenericNew, // tp_new + PyObject_Del, // tp_free +}; + +namespace cdescriptor_pool { + +static void Dealloc(CDescriptorPool* self) { + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + +static PyObject* NewCDescriptor( + const google::protobuf::FieldDescriptor* field_descriptor) { + CFieldDescriptor* cfield_descriptor = PyObject_New( + CFieldDescriptor, &CFieldDescriptor_Type); + if (cfield_descriptor == NULL) { + return NULL; + } + cfield_descriptor->descriptor = field_descriptor; + cfield_descriptor->descriptor_field = NULL; + + return reinterpret_cast(cfield_descriptor); +} + +PyObject* FindFieldByName(CDescriptorPool* self, PyObject* name) { + const char* full_field_name = PyString_AsString(name); + if (full_field_name == NULL) { + return NULL; + } + + const google::protobuf::FieldDescriptor* field_descriptor = NULL; + + field_descriptor = self->pool->FindFieldByName(full_field_name); + + if (field_descriptor == NULL) { + PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s", + full_field_name); + return NULL; + } + + return NewCDescriptor(field_descriptor); +} + +PyObject* FindExtensionByName(CDescriptorPool* self, PyObject* arg) { + const char* full_field_name = PyString_AsString(arg); + if (full_field_name == NULL) { + return NULL; + } + + const google::protobuf::FieldDescriptor* field_descriptor = + self->pool->FindExtensionByName(full_field_name); + if (field_descriptor == NULL) { + PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s", + full_field_name); + return NULL; + } + + return NewCDescriptor(field_descriptor); +} + +static PyMethodDef Methods[] = { + { C("FindFieldByName"), + (PyCFunction)FindFieldByName, + METH_O, + C("Searches for a field descriptor by full name.") }, + { C("FindExtensionByName"), + (PyCFunction)FindExtensionByName, + METH_O, + C("Searches for extension descriptor by full name.") }, + {NULL} +}; + +} // namespace cdescriptor_pool + +PyTypeObject CDescriptorPool_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + C("google.protobuf.internal." + "_net_proto2___python." + "CFieldDescriptor"), // tp_name + sizeof(CDescriptorPool), // tp_basicsize + 0, // tp_itemsize + (destructor)cdescriptor_pool::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + C("A Descriptor Pool"), // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + cdescriptor_pool::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + PyType_GenericAlloc, // tp_alloc + PyType_GenericNew, // tp_new + PyObject_Del, // tp_free +}; + +google::protobuf::DescriptorPool* GetDescriptorPool() { + if (g_descriptor_pool == NULL) { + g_descriptor_pool = new google::protobuf::DescriptorPool( + google::protobuf::DescriptorPool::generated_pool()); + } + return g_descriptor_pool; +} + +PyObject* Python_NewCDescriptorPool(PyObject* ignored, PyObject* args) { + CDescriptorPool* cdescriptor_pool = PyObject_New( + CDescriptorPool, &CDescriptorPool_Type); + if (cdescriptor_pool == NULL) { + return NULL; + } + cdescriptor_pool->pool = GetDescriptorPool(); + return reinterpret_cast(cdescriptor_pool); +} + + +// Collects errors that occur during proto file building to allow them to be +// propagated in the python exception instead of only living in ERROR logs. +class BuildFileErrorCollector : public google::protobuf::DescriptorPool::ErrorCollector { + public: + BuildFileErrorCollector() : error_message(""), had_errors(false) {} + + void AddError(const string& filename, const string& element_name, + const Message* descriptor, ErrorLocation location, + const string& message) { + // Replicates the logging behavior that happens in the C++ implementation + // when an error collector is not passed in. + if (!had_errors) { + error_message += + ("Invalid proto descriptor for file \"" + filename + "\":\n"); + } + // As this only happens on failure and will result in the program not + // running at all, no effort is made to optimize this string manipulation. + error_message += (" " + element_name + ": " + message + "\n"); + } + + string error_message; + bool had_errors; +}; + +PyObject* Python_BuildFile(PyObject* ignored, PyObject* arg) { + char* message_type; + Py_ssize_t message_len; + + if (PyBytes_AsStringAndSize(arg, &message_type, &message_len) < 0) { + return NULL; + } + + google::protobuf::FileDescriptorProto file_proto; + if (!file_proto.ParseFromArray(message_type, message_len)) { + PyErr_SetString(PyExc_TypeError, "Couldn't parse file content!"); + return NULL; + } + + if (google::protobuf::DescriptorPool::generated_pool()->FindFileByName( + file_proto.name()) != NULL) { + Py_RETURN_NONE; + } + + BuildFileErrorCollector error_collector; + const google::protobuf::FileDescriptor* descriptor = + GetDescriptorPool()->BuildFileCollectingErrors(file_proto, + &error_collector); + if (descriptor == NULL) { + PyErr_Format(PyExc_TypeError, + "Couldn't build proto file into descriptor pool!\n%s", + error_collector.error_message.c_str()); + return NULL; + } + + Py_RETURN_NONE; +} + +bool InitDescriptor() { + CFieldDescriptor_Type.tp_new = PyType_GenericNew; + if (PyType_Ready(&CFieldDescriptor_Type) < 0) + return false; + + CDescriptorPool_Type.tp_new = PyType_GenericNew; + if (PyType_Ready(&CDescriptorPool_Type) < 0) + return false; + + return true; +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/descriptor.h b/python/google/protobuf/pyext/descriptor.h new file mode 100644 index 00000000..d114425a --- /dev/null +++ b/python/google/protobuf/pyext/descriptor.h @@ -0,0 +1,96 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: petar@google.com (Petar Petrov) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__ + +#include +#include + +#include + +#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN) +typedef int Py_ssize_t; +#define PY_SSIZE_T_MAX INT_MAX +#define PY_SSIZE_T_MIN INT_MIN +#endif + +namespace google { +namespace protobuf { +namespace python { + +typedef struct CFieldDescriptor { + PyObject_HEAD + + // The proto2 descriptor that this object represents. + const google::protobuf::FieldDescriptor* descriptor; + + // Reference to the original field object in the Python DESCRIPTOR. + PyObject* descriptor_field; +} CFieldDescriptor; + +typedef struct { + PyObject_HEAD + + const google::protobuf::DescriptorPool* pool; +} CDescriptorPool; + +extern PyTypeObject CFieldDescriptor_Type; + +extern PyTypeObject CDescriptorPool_Type; + +namespace cdescriptor_pool { + +// Looks up a field by name. Returns a CDescriptor corresponding to +// the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindFieldByName(CDescriptorPool* self, PyObject* name); + +// Looks up an extension by name. Returns a CDescriptor corresponding +// to the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindExtensionByName(CDescriptorPool* self, PyObject* arg); + +} // namespace cdescriptor_pool + +PyObject* Python_NewCDescriptorPool(PyObject* ignored, PyObject* args); +PyObject* Python_BuildFile(PyObject* ignored, PyObject* args); +bool InitDescriptor(); +google::protobuf::DescriptorPool* GetDescriptorPool(); + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__ diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc new file mode 100644 index 00000000..1e14b421 --- /dev/null +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -0,0 +1,338 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace google { +namespace protobuf { +namespace python { + +extern google::protobuf::DynamicMessageFactory* global_message_factory; + +namespace extension_dict { + +// TODO(tibell): Always use self->message for clarity, just like in +// RepeatedCompositeContainer. +static google::protobuf::Message* GetMessage(ExtensionDict* self) { + if (self->parent != NULL) { + return self->parent->message; + } else { + return self->message; + } +} + +CFieldDescriptor* InternalGetCDescriptorFromExtension(PyObject* extension) { + PyObject* cdescriptor = PyObject_GetAttrString(extension, "_cdescriptor"); + if (cdescriptor == NULL) { + PyErr_SetString(PyExc_KeyError, "Unregistered extension."); + return NULL; + } + if (!PyObject_TypeCheck(cdescriptor, &CFieldDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a CFieldDescriptor"); + Py_DECREF(cdescriptor); + return NULL; + } + CFieldDescriptor* descriptor = + reinterpret_cast(cdescriptor); + return descriptor; +} + +PyObject* len(ExtensionDict* self) { +#if PY_MAJOR_VERSION >= 3 + return PyLong_FromLong(PyDict_Size(self->values)); +#else + return PyInt_FromLong(PyDict_Size(self->values)); +#endif +} + +// TODO(tibell): Use VisitCompositeField. +int ReleaseExtension(ExtensionDict* self, + PyObject* extension, + const google::protobuf::FieldDescriptor* descriptor) { + if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + if (descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (repeated_composite_container::Release( + reinterpret_cast( + extension)) < 0) { + return -1; + } + } else { + if (repeated_scalar_container::Release( + reinterpret_cast( + extension)) < 0) { + return -1; + } + } + } else if (descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (cmessage::ReleaseSubMessage( + GetMessage(self), descriptor, + reinterpret_cast(extension)) < 0) { + return -1; + } + } + + return 0; +} + +PyObject* subscript(ExtensionDict* self, PyObject* key) { + CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension( + key); + if (cdescriptor == NULL) { + return NULL; + } + ScopedPyObjectPtr py_cdescriptor(reinterpret_cast(cdescriptor)); + const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor; + if (descriptor == NULL) { + return NULL; + } + if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && + descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { + return cmessage::InternalGetScalar(self->parent, descriptor); + } + + PyObject* value = PyDict_GetItem(self->values, key); + if (value != NULL) { + Py_INCREF(value); + return value; + } + + if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && + descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + PyObject* sub_message = cmessage::InternalGetSubMessage( + self->parent, cdescriptor); + if (sub_message == NULL) { + return NULL; + } + PyDict_SetItem(self->values, key, sub_message); + return sub_message; + } + + if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + // COPIED + PyObject* py_container = PyObject_CallObject( + reinterpret_cast(&RepeatedCompositeContainer_Type), + NULL); + if (py_container == NULL) { + return NULL; + } + RepeatedCompositeContainer* container = + reinterpret_cast(py_container); + PyObject* field = cdescriptor->descriptor_field; + PyObject* message_type = PyObject_GetAttrString(field, "message_type"); + PyObject* concrete_class = PyObject_GetAttrString(message_type, + "_concrete_class"); + container->owner = self->owner; + container->parent = self->parent; + container->message = self->parent->message; + container->parent_field = cdescriptor; + container->subclass_init = concrete_class; + Py_DECREF(message_type); + PyDict_SetItem(self->values, key, py_container); + return py_container; + } else { + // COPIED + ScopedPyObjectPtr init_args(PyTuple_Pack(2, self->parent, cdescriptor)); + PyObject* py_container = PyObject_CallObject( + reinterpret_cast(&RepeatedScalarContainer_Type), + init_args); + if (py_container == NULL) { + return NULL; + } + PyDict_SetItem(self->values, key, py_container); + return py_container; + } + } + PyErr_SetString(PyExc_ValueError, "control reached unexpected line"); + return NULL; +} + +int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { + CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension( + key); + if (cdescriptor == NULL) { + return -1; + } + ScopedPyObjectPtr py_cdescriptor(reinterpret_cast(cdescriptor)); + const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor; + if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL || + descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite " + "type"); + return -1; + } + cmessage::AssureWritable(self->parent); + if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) { + return -1; + } + // TODO(tibell): We shouldn't write scalars to the cache. + PyDict_SetItem(self->values, key, value); + return 0; +} + +PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) { + CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension( + extension); + if (cdescriptor == NULL) { + return NULL; + } + ScopedPyObjectPtr py_cdescriptor(reinterpret_cast(cdescriptor)); + PyObject* value = PyDict_GetItem(self->values, extension); + if (value != NULL) { + if (ReleaseExtension(self, value, cdescriptor->descriptor) < 0) { + return NULL; + } + } + if (cmessage::ClearFieldByDescriptor(self->parent, + cdescriptor->descriptor) == NULL) { + return NULL; + } + if (PyDict_DelItem(self->values, extension) < 0) { + PyErr_Clear(); + } + Py_RETURN_NONE; +} + +PyObject* HasExtension(ExtensionDict* self, PyObject* extension) { + CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension( + extension); + if (cdescriptor == NULL) { + return NULL; + } + ScopedPyObjectPtr py_cdescriptor(reinterpret_cast(cdescriptor)); + PyObject* result = cmessage::HasFieldByDescriptor( + self->parent, cdescriptor->descriptor); + return result; +} + +PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { + ScopedPyObjectPtr extensions_by_name(PyObject_GetAttrString( + reinterpret_cast(self->parent), "_extensions_by_name")); + if (extensions_by_name == NULL) { + return NULL; + } + PyObject* result = PyDict_GetItem(extensions_by_name, name); + if (result == NULL) { + Py_RETURN_NONE; + } else { + Py_INCREF(result); + return result; + } +} + +int init(ExtensionDict* self, PyObject* args, PyObject* kwargs) { + self->parent = NULL; + self->message = NULL; + self->values = PyDict_New(); + return 0; +} + +void dealloc(ExtensionDict* self) { + Py_CLEAR(self->values); + self->owner.reset(); + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + +static PyMappingMethods MpMethods = { + (lenfunc)len, /* mp_length */ + (binaryfunc)subscript, /* mp_subscript */ + (objobjargproc)ass_subscript,/* mp_ass_subscript */ +}; + +#define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc } +static PyMethodDef Methods[] = { + EDMETHOD(ClearExtension, METH_O, "Clears an extension from the object."), + EDMETHOD(HasExtension, METH_O, "Checks if the object has an extension."), + EDMETHOD(_FindExtensionByName, METH_O, + "Finds an extension by name."), + { NULL, NULL } +}; + +} // namespace extension_dict + +PyTypeObject ExtensionDict_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "google.protobuf.internal." + "cpp._message.ExtensionDict", // tp_name + sizeof(ExtensionDict), // tp_basicsize + 0, // tp_itemsize + (destructor)extension_dict::dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + &extension_dict::MpMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "An extension dict", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + extension_dict::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + (initproc)extension_dict::init, // tp_init +}; + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/extension_dict.h b/python/google/protobuf/pyext/extension_dict.h new file mode 100644 index 00000000..13430017 --- /dev/null +++ b/python/google/protobuf/pyext/extension_dict.h @@ -0,0 +1,123 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__ + +#include + +#include +#ifndef _SHARED_PTR_H +#include +#endif + + +namespace google { +namespace protobuf { + +class Message; +class FieldDescriptor; + +using internal::shared_ptr; + +namespace python { + +struct CMessage; +struct CFieldDescriptor; + +typedef struct ExtensionDict { + PyObject_HEAD; + shared_ptr owner; + CMessage* parent; + Message* message; + PyObject* values; +} ExtensionDict; + +extern PyTypeObject ExtensionDict_Type; + +namespace extension_dict { + +// Gets the _cdescriptor reference to a CFieldDescriptor object given a +// python descriptor object. +// +// Returns a new reference. +CFieldDescriptor* InternalGetCDescriptorFromExtension(PyObject* extension); + +// Gets the number of extension values in this ExtensionDict as a python object. +// +// Returns a new reference. +PyObject* len(ExtensionDict* self); + +// Releases extensions referenced outside this dictionary to keep outside +// references alive. +// +// Returns 0 on success, -1 on failure. +int ReleaseExtension(ExtensionDict* self, + PyObject* extension, + const google::protobuf::FieldDescriptor* descriptor); + +// Gets an extension from the dict for the given extension descriptor. +// +// Returns a new reference. +PyObject* subscript(ExtensionDict* self, PyObject* key); + +// Assigns a value to an extension in the dict. Can only be used for singular +// simple types. +// +// Returns 0 on success, -1 on failure. +int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value); + +// Clears an extension from the dict. Will release the extension if there +// is still an external reference left to it. +// +// Returns None on success. +PyObject* ClearExtension(ExtensionDict* self, + PyObject* extension); + +// Checks if the dict has an extension. +// +// Returns a new python boolean reference. +PyObject* HasExtension(ExtensionDict* self, PyObject* extension); + +// Gets an extension from the dict given the extension name as opposed to +// descriptor. +// +// Returns a new reference. +PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name); + +} // namespace extension_dict +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__ diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc new file mode 100644 index 00000000..c45cbf0d --- /dev/null +++ b/python/google/protobuf/pyext/message.cc @@ -0,0 +1,2561 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#include + +#include +#ifndef _SHARED_PTR_H +#include +#endif +#include +#include + +#ifndef PyVarObject_HEAD_INIT +#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size, +#endif +#ifndef Py_TYPE +#define Py_TYPE(ob) (((PyObject*)(ob))->ob_type) +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if PY_MAJOR_VERSION >= 3 + #define PyInt_Check PyLong_Check + #define PyInt_AsLong PyLong_AsLong + #define PyInt_FromLong PyLong_FromLong + #define PyInt_FromSize_t PyLong_FromSize_t + #define PyString_Check PyUnicode_Check + #define PyString_FromString PyUnicode_FromString + #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #else + #define PyString_AsString(ob) \ + (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob)) + #endif +#endif + +namespace google { +namespace protobuf { +namespace python { + +// Forward declarations +namespace cmessage { +static PyObject* GetDescriptor(CMessage* self, PyObject* name); +static string GetMessageName(CMessage* self); +int InternalReleaseFieldByDescriptor( + const google::protobuf::FieldDescriptor* field_descriptor, + PyObject* composite_field, + google::protobuf::Message* parent_message); +} // namespace cmessage + +// --------------------------------------------------------------------- +// Visiting the composite children of a CMessage + +struct ChildVisitor { + // Returns 0 on success, -1 on failure. + int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { + return 0; + } + + // Returns 0 on success, -1 on failure. + int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) { + return 0; + } + + // Returns 0 on success, -1 on failure. + int VisitCMessage(CMessage* cmessage, + const google::protobuf::FieldDescriptor* field_descriptor) { + return 0; + } +}; + +// Apply a function to a composite field. Does nothing if child is of +// non-composite type. +template +static int VisitCompositeField(const FieldDescriptor* descriptor, + PyObject* child, + Visitor visitor) { + if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + if (descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + RepeatedCompositeContainer* container = + reinterpret_cast(child); + if (visitor.VisitRepeatedCompositeContainer(container) == -1) + return -1; + } else { + RepeatedScalarContainer* container = + reinterpret_cast(child); + if (visitor.VisitRepeatedScalarContainer(container) == -1) + return -1; + } + } else if (descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + CMessage* cmsg = reinterpret_cast(child); + if (visitor.VisitCMessage(cmsg, descriptor) == -1) + return -1; + } + // The ExtensionDict might contain non-composite fields, which we + // skip here. + return 0; +} + +// Visit each composite field and extension field of this CMessage. +// Returns -1 on error and 0 on success. +template +int ForEachCompositeField(CMessage* self, Visitor visitor) { + Py_ssize_t pos = 0; + PyObject* key; + PyObject* field; + + // Visit normal fields. + while (PyDict_Next(self->composite_fields, &pos, &key, &field)) { + PyObject* cdescriptor = cmessage::GetDescriptor(self, key); + if (cdescriptor != NULL) { + const google::protobuf::FieldDescriptor* descriptor = + reinterpret_cast(cdescriptor)->descriptor; + if (VisitCompositeField(descriptor, field, visitor) == -1) + return -1; + } + } + + // Visit extension fields. + if (self->extensions != NULL) { + while (PyDict_Next(self->extensions->values, &pos, &key, &field)) { + CFieldDescriptor* cdescriptor = + extension_dict::InternalGetCDescriptorFromExtension(key); + if (cdescriptor == NULL) + return -1; + if (VisitCompositeField(cdescriptor->descriptor, field, visitor) == -1) + return -1; + } + } + + return 0; +} + +// --------------------------------------------------------------------- + +// Constants used for integer type range checking. +PyObject* kPythonZero; +PyObject* kint32min_py; +PyObject* kint32max_py; +PyObject* kuint32max_py; +PyObject* kint64min_py; +PyObject* kint64max_py; +PyObject* kuint64max_py; + +PyObject* EnumTypeWrapper_class; +PyObject* EncodeError_class; +PyObject* DecodeError_class; +PyObject* PickleError_class; + +// Constant PyString values used for GetAttr/GetItem. +static PyObject* kDESCRIPTOR; +static PyObject* k__descriptors; +static PyObject* kfull_name; +static PyObject* kname; +static PyObject* kmessage_type; +static PyObject* kis_extendable; +static PyObject* kextensions_by_name; +static PyObject* k_extensions_by_name; +static PyObject* k_extensions_by_number; +static PyObject* k_concrete_class; +static PyObject* kfields_by_name; + +static CDescriptorPool* descriptor_pool; + +/* Is 64bit */ +void FormatTypeError(PyObject* arg, char* expected_types) { + PyObject* repr = PyObject_Repr(arg); + if (repr) { + PyErr_Format(PyExc_TypeError, + "%.100s has type %.100s, but expected one of: %s", + PyString_AsString(repr), + Py_TYPE(arg)->tp_name, + expected_types); + Py_DECREF(repr); + } +} + +template +bool CheckAndGetInteger( + PyObject* arg, T* value, PyObject* min, PyObject* max) { + bool is_long = PyLong_Check(arg); +#if PY_MAJOR_VERSION < 3 + if (!PyInt_Check(arg) && !is_long) { + FormatTypeError(arg, "int, long"); + return false; + } + if (PyObject_Compare(min, arg) > 0 || PyObject_Compare(max, arg) < 0) { +#else + if (!is_long) { + FormatTypeError(arg, "int"); + return false; + } + if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 || + PyObject_RichCompareBool(max, arg, Py_GE) != 1) { +#endif + PyObject *s = PyObject_Str(arg); + if (s) { + PyErr_Format(PyExc_ValueError, + "Value out of range: %s", + PyString_AsString(s)); + Py_DECREF(s); + } + return false; + } +#if PY_MAJOR_VERSION < 3 + if (!is_long) { + *value = static_cast(PyInt_AsLong(arg)); + } else // NOLINT +#endif + { + if (min == kPythonZero) { + *value = static_cast(PyLong_AsUnsignedLongLong(arg)); + } else { + *value = static_cast(PyLong_AsLongLong(arg)); + } + } + return true; +} + +// These are referenced by repeated_scalar_container, and must +// be explicitly instantiated. +template bool CheckAndGetInteger( + PyObject*, int32*, PyObject*, PyObject*); +template bool CheckAndGetInteger( + PyObject*, int64*, PyObject*, PyObject*); +template bool CheckAndGetInteger( + PyObject*, uint32*, PyObject*, PyObject*); +template bool CheckAndGetInteger( + PyObject*, uint64*, PyObject*, PyObject*); + +bool CheckAndGetDouble(PyObject* arg, double* value) { + if (!PyInt_Check(arg) && !PyLong_Check(arg) && + !PyFloat_Check(arg)) { + FormatTypeError(arg, "int, long, float"); + return false; + } + *value = PyFloat_AsDouble(arg); + return true; +} + +bool CheckAndGetFloat(PyObject* arg, float* value) { + double double_value; + if (!CheckAndGetDouble(arg, &double_value)) { + return false; + } + *value = static_cast(double_value); + return true; +} + +bool CheckAndGetBool(PyObject* arg, bool* value) { + if (!PyInt_Check(arg) && !PyBool_Check(arg) && !PyLong_Check(arg)) { + FormatTypeError(arg, "int, long, bool"); + return false; + } + *value = static_cast(PyInt_AsLong(arg)); + return true; +} + +bool CheckAndSetString( + PyObject* arg, google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* descriptor, + const google::protobuf::Reflection* reflection, + bool append, + int index) { + GOOGLE_DCHECK(descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING || + descriptor->type() == google::protobuf::FieldDescriptor::TYPE_BYTES); + if (descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) { + if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) { + FormatTypeError(arg, "bytes, unicode"); + return false; + } + + if (PyBytes_Check(arg)) { + PyObject* unicode = PyUnicode_FromEncodedObject(arg, "ascii", NULL); + if (unicode == NULL) { + PyObject* repr = PyObject_Repr(arg); + PyErr_Format(PyExc_ValueError, + "%s has type str, but isn't in 7-bit ASCII " + "encoding. Non-ASCII strings must be converted to " + "unicode objects before being added.", + PyString_AsString(repr)); + Py_DECREF(repr); + return false; + } else { + Py_DECREF(unicode); + } + } + } else if (!PyBytes_Check(arg)) { + FormatTypeError(arg, "bytes"); + return false; + } + + PyObject* encoded_string = NULL; + if (descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) { + if (PyBytes_Check(arg)) { +#if PY_MAJOR_VERSION < 3 + encoded_string = PyString_AsEncodedObject(arg, "utf-8", NULL); +#else + encoded_string = arg; // Already encoded. + Py_INCREF(encoded_string); +#endif + } else { + encoded_string = PyUnicode_AsEncodedObject(arg, "utf-8", NULL); + } + } else { + // In this case field type is "bytes". + encoded_string = arg; + Py_INCREF(encoded_string); + } + + if (encoded_string == NULL) { + return false; + } + + char* value; + Py_ssize_t value_len; + if (PyBytes_AsStringAndSize(encoded_string, &value, &value_len) < 0) { + Py_DECREF(encoded_string); + return false; + } + + string value_string(value, value_len); + if (append) { + reflection->AddString(message, descriptor, value_string); + } else if (index < 0) { + reflection->SetString(message, descriptor, value_string); + } else { + reflection->SetRepeatedString(message, descriptor, index, value_string); + } + Py_DECREF(encoded_string); + return true; +} + +PyObject* ToStringObject( + const google::protobuf::FieldDescriptor* descriptor, string value) { + if (descriptor->type() != google::protobuf::FieldDescriptor::TYPE_STRING) { + return PyBytes_FromStringAndSize(value.c_str(), value.length()); + } + + PyObject* result = PyUnicode_DecodeUTF8(value.c_str(), value.length(), NULL); + // If the string can't be decoded in UTF-8, just return a string object that + // contains the raw bytes. This can't happen if the value was assigned using + // the members of the Python message object, but can happen if the values were + // parsed from the wire (binary). + if (result == NULL) { + PyErr_Clear(); + result = PyBytes_FromStringAndSize(value.c_str(), value.length()); + } + return result; +} + +google::protobuf::DynamicMessageFactory* global_message_factory; + +namespace cmessage { + +static int MaybeReleaseOverlappingOneofField( + CMessage* cmessage, + const google::protobuf::FieldDescriptor* field) { +#ifdef GOOGLE_PROTOBUF_HAS_ONEOF + google::protobuf::Message* message = cmessage->message; + const google::protobuf::Reflection* reflection = message->GetReflection(); + if (!field->containing_oneof() || + !reflection->HasOneof(*message, field->containing_oneof()) || + reflection->HasField(*message, field)) { + // No other field in this oneof, no need to release. + return 0; + } + + const OneofDescriptor* oneof = field->containing_oneof(); + const FieldDescriptor* existing_field = + reflection->GetOneofFieldDescriptor(*message, oneof); + if (existing_field->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + // Non-message fields don't need to be released. + return 0; + } + const char* field_name = existing_field->name().c_str(); + PyObject* child_message = PyDict_GetItemString( + cmessage->composite_fields, field_name); + if (child_message == NULL) { + // No python reference to this field so no need to release. + return 0; + } + + if (InternalReleaseFieldByDescriptor( + existing_field, child_message, message) < 0) { + return -1; + } + return PyDict_DelItemString(cmessage->composite_fields, field_name); +#else + return 0; +#endif +} + +// --------------------------------------------------------------------- +// Making a message writable + +static google::protobuf::Message* GetMutableMessage( + CMessage* parent, + const google::protobuf::FieldDescriptor* parent_field) { + google::protobuf::Message* parent_message = parent->message; + const google::protobuf::Reflection* reflection = parent_message->GetReflection(); + if (MaybeReleaseOverlappingOneofField(parent, parent_field) < 0) { + return NULL; + } + return reflection->MutableMessage( + parent_message, parent_field, global_message_factory); +} + +struct FixupMessageReference : public ChildVisitor { + // message must outlive this object. + explicit FixupMessageReference(google::protobuf::Message* message) : + message_(message) {} + + int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { + container->message = message_; + return 0; + } + + int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) { + container->message = message_; + return 0; + } + + private: + google::protobuf::Message* message_; +}; + +int AssureWritable(CMessage* self) { + if (self == NULL || !self->read_only) { + return 0; + } + + if (self->parent == NULL) { + // If parent is NULL but we are trying to modify a read-only message, this + // is a reference to a constant default instance that needs to be replaced + // with a mutable top-level message. + const Message* prototype = global_message_factory->GetPrototype( + self->message->GetDescriptor()); + self->message = prototype->New(); + self->owner.reset(self->message); + } else { + // Otherwise, we need a mutable child message. + if (AssureWritable(self->parent) == -1) + return -1; + + // Make self->message writable. + google::protobuf::Message* parent_message = self->parent->message; + google::protobuf::Message* mutable_message = GetMutableMessage( + self->parent, + self->parent_field->descriptor); + if (mutable_message == NULL) { + return -1; + } + self->message = mutable_message; + } + self->read_only = false; + + // When a CMessage is made writable its Message pointer is updated + // to point to a new mutable Message. When that happens we need to + // update any references to the old, read-only CMessage. There are + // three places such references occur: RepeatedScalarContainer, + // RepeatedCompositeContainer, and ExtensionDict. + if (self->extensions != NULL) + self->extensions->message = self->message; + if (ForEachCompositeField(self, FixupMessageReference(self->message)) == -1) + return -1; + + return 0; +} + +// --- Globals: + +static PyObject* GetDescriptor(CMessage* self, PyObject* name) { + PyObject* descriptors = + PyDict_GetItem(Py_TYPE(self)->tp_dict, k__descriptors); + if (descriptors == NULL) { + PyErr_SetString(PyExc_TypeError, "No __descriptors"); + return NULL; + } + + return PyDict_GetItem(descriptors, name); +} + +static const google::protobuf::Message* CreateMessage(const char* message_type) { + string message_name(message_type); + const google::protobuf::Descriptor* descriptor = + GetDescriptorPool()->FindMessageTypeByName(message_name); + if (descriptor == NULL) { + PyErr_SetString(PyExc_TypeError, message_type); + return NULL; + } + return global_message_factory->GetPrototype(descriptor); +} + +// If cmessage_list is not NULL, this function releases values into the +// container CMessages instead of just removing. Repeated composite container +// needs to do this to make sure CMessages stay alive if they're still +// referenced after deletion. Repeated scalar container doesn't need to worry. +int InternalDeleteRepeatedField( + google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field_descriptor, + PyObject* slice, + PyObject* cmessage_list) { + Py_ssize_t length, from, to, step, slice_length; + const google::protobuf::Reflection* reflection = message->GetReflection(); + int min, max; + length = reflection->FieldSize(*message, field_descriptor); + + if (PyInt_Check(slice) || PyLong_Check(slice)) { + from = to = PyLong_AsLong(slice); + if (from < 0) { + from = to = length + from; + } + step = 1; + min = max = from; + + // Range check. + if (from < 0 || from >= length) { + PyErr_Format(PyExc_IndexError, "list assignment index out of range"); + return -1; + } + } else if (PySlice_Check(slice)) { + from = to = step = slice_length = 0; + PySlice_GetIndicesEx( +#if PY_MAJOR_VERSION < 3 + reinterpret_cast(slice), +#else + slice, +#endif + length, &from, &to, &step, &slice_length); + if (from < to) { + min = from; + max = to - 1; + } else { + min = to + 1; + max = from; + } + } else { + PyErr_SetString(PyExc_TypeError, "list indices must be integers"); + return -1; + } + + Py_ssize_t i = from; + std::vector to_delete(length, false); + while (i >= min && i <= max) { + to_delete[i] = true; + i += step; + } + + to = 0; + for (i = 0; i < length; ++i) { + if (!to_delete[i]) { + if (i != to) { + reflection->SwapElements(message, field_descriptor, i, to); + if (cmessage_list != NULL) { + // If a list of cmessages is passed in (i.e. from a repeated + // composite container), swap those as well to correspond to the + // swaps in the underlying message so they're in the right order + // when we start releasing. + PyObject* tmp = PyList_GET_ITEM(cmessage_list, i); + PyList_SET_ITEM(cmessage_list, i, + PyList_GET_ITEM(cmessage_list, to)); + PyList_SET_ITEM(cmessage_list, to, tmp); + } + } + ++to; + } + } + + while (i > to) { + if (cmessage_list == NULL) { + reflection->RemoveLast(message, field_descriptor); + } else { + CMessage* last_cmessage = reinterpret_cast( + PyList_GET_ITEM(cmessage_list, PyList_GET_SIZE(cmessage_list) - 1)); + repeated_composite_container::ReleaseLastTo( + field_descriptor, message, last_cmessage); + if (PySequence_DelItem(cmessage_list, -1) < 0) { + return -1; + } + } + --i; + } + + return 0; +} + +int InitAttributes(CMessage* self, PyObject* arg, PyObject* kwargs) { + ScopedPyObjectPtr descriptor; + if (arg == NULL) { + descriptor.reset( + PyObject_GetAttr(reinterpret_cast(self), kDESCRIPTOR)); + if (descriptor == NULL) { + return NULL; + } + } else { + descriptor.reset(arg); + descriptor.inc(); + } + ScopedPyObjectPtr is_extendable(PyObject_GetAttr(descriptor, kis_extendable)); + if (is_extendable == NULL) { + return NULL; + } + int retcode = PyObject_IsTrue(is_extendable); + if (retcode == -1) { + return NULL; + } + if (retcode) { + PyObject* py_extension_dict = PyObject_CallObject( + reinterpret_cast(&ExtensionDict_Type), NULL); + if (py_extension_dict == NULL) { + return NULL; + } + ExtensionDict* extension_dict = reinterpret_cast( + py_extension_dict); + extension_dict->parent = self; + extension_dict->message = self->message; + self->extensions = extension_dict; + } + + if (kwargs == NULL) { + return 0; + } + + Py_ssize_t pos = 0; + PyObject* name; + PyObject* value; + while (PyDict_Next(kwargs, &pos, &name, &value)) { + if (!PyString_Check(name)) { + PyErr_SetString(PyExc_ValueError, "Field name must be a string"); + return -1; + } + PyObject* py_cdescriptor = GetDescriptor(self, name); + if (py_cdescriptor == NULL) { + PyErr_Format(PyExc_ValueError, "Protocol message has no \"%s\" field.", + PyString_AsString(name)); + return -1; + } + const google::protobuf::FieldDescriptor* descriptor = + reinterpret_cast(py_cdescriptor)->descriptor; + if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + ScopedPyObjectPtr container(GetAttr(self, name)); + if (container == NULL) { + return -1; + } + if (descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (repeated_composite_container::Extend( + reinterpret_cast(container.get()), + value) + == NULL) { + return -1; + } + } else { + if (repeated_scalar_container::Extend( + reinterpret_cast(container.get()), + value) == + NULL) { + return -1; + } + } + } else if (descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + ScopedPyObjectPtr message(GetAttr(self, name)); + if (message == NULL) { + return -1; + } + if (MergeFrom(reinterpret_cast(message.get()), + value) == NULL) { + return -1; + } + } else { + if (SetAttr(self, name, value) < 0) { + return -1; + } + } + } + return 0; +} + +static PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) { + CMessage* self = reinterpret_cast(type->tp_alloc(type, 0)); + if (self == NULL) { + return NULL; + } + + self->message = NULL; + self->parent = NULL; + self->parent_field = NULL; + self->read_only = false; + self->extensions = NULL; + + self->composite_fields = PyDict_New(); + if (self->composite_fields == NULL) { + return NULL; + } + return reinterpret_cast(self); +} + +PyObject* NewEmpty(PyObject* type) { + return New(reinterpret_cast(type), NULL, NULL); +} + +static int Init(CMessage* self, PyObject* args, PyObject* kwargs) { + if (kwargs == NULL) { + // TODO(anuraag): Set error + return -1; + } + + PyObject* descriptor = PyTuple_GetItem(args, 0); + if (descriptor == NULL || PyTuple_Size(args) != 1) { + PyErr_SetString(PyExc_ValueError, "args must contain one arg: descriptor"); + return -1; + } + + ScopedPyObjectPtr py_message_type(PyObject_GetAttr(descriptor, kfull_name)); + if (py_message_type == NULL) { + return -1; + } + + const char* message_type = PyString_AsString(py_message_type.get()); + const google::protobuf::Message* message = CreateMessage(message_type); + if (message == NULL) { + return -1; + } + + self->message = message->New(); + self->owner.reset(self->message); + + if (InitAttributes(self, descriptor, kwargs) < 0) { + return -1; + } + return 0; +} + +// --------------------------------------------------------------------- +// Deallocating a CMessage +// +// Deallocating a CMessage requires that we clear any weak references +// from children to the message being deallocated. + +// Clear the weak reference from the child to the parent. +struct ClearWeakReferences : public ChildVisitor { + int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { + container->parent = NULL; + // The elements in the container have the same parent as the + // container itself, so NULL out that pointer as well. + const Py_ssize_t n = PyList_GET_SIZE(container->child_messages); + for (Py_ssize_t i = 0; i < n; ++i) { + CMessage* child_cmessage = reinterpret_cast( + PyList_GET_ITEM(container->child_messages, i)); + child_cmessage->parent = NULL; + } + return 0; + } + + int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) { + container->parent = NULL; + return 0; + } + + int VisitCMessage(CMessage* cmessage, + const google::protobuf::FieldDescriptor* field_descriptor) { + cmessage->parent = NULL; + return 0; + } +}; + +static void Dealloc(CMessage* self) { + // Null out all weak references from children to this message. + GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences())); + + Py_CLEAR(self->extensions); + Py_CLEAR(self->composite_fields); + self->owner.reset(); + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + +// --------------------------------------------------------------------- + + +PyObject* IsInitialized(CMessage* self, PyObject* args) { + PyObject* errors = NULL; + if (PyArg_ParseTuple(args, "|O", &errors) < 0) { + return NULL; + } + if (self->message->IsInitialized()) { + Py_RETURN_TRUE; + } + if (errors != NULL) { + ScopedPyObjectPtr initialization_errors( + FindInitializationErrors(self)); + if (initialization_errors == NULL) { + return NULL; + } + ScopedPyObjectPtr extend_name(PyString_FromString("extend")); + if (extend_name == NULL) { + return NULL; + } + ScopedPyObjectPtr result(PyObject_CallMethodObjArgs( + errors, + extend_name.get(), + initialization_errors.get(), + NULL)); + if (result == NULL) { + return NULL; + } + } + Py_RETURN_FALSE; +} + +PyObject* HasFieldByDescriptor( + CMessage* self, const google::protobuf::FieldDescriptor* field_descriptor) { + google::protobuf::Message* message = self->message; + if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { + PyErr_SetString(PyExc_KeyError, + "Field does not belong to message!"); + return NULL; + } + if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + PyErr_SetString(PyExc_KeyError, + "Field is repeated. A singular method is required."); + return NULL; + } + bool has_field = + message->GetReflection()->HasField(*message, field_descriptor); + return PyBool_FromLong(has_field ? 1 : 0); +} + +const google::protobuf::FieldDescriptor* FindFieldWithOneofs( + const google::protobuf::Message* message, const char* field_name, bool* in_oneof) { + const google::protobuf::Descriptor* descriptor = message->GetDescriptor(); + const google::protobuf::FieldDescriptor* field_descriptor = + descriptor->FindFieldByName(field_name); + if (field_descriptor == NULL) { + const google::protobuf::OneofDescriptor* oneof_desc = + message->GetDescriptor()->FindOneofByName(field_name); + if (oneof_desc == NULL) { + *in_oneof = false; + return NULL; + } else { + *in_oneof = true; + return message->GetReflection()->GetOneofFieldDescriptor( + *message, oneof_desc); + } + } + return field_descriptor; +} + +PyObject* HasField(CMessage* self, PyObject* arg) { +#if PY_MAJOR_VERSION < 3 + char* field_name; + if (PyString_AsStringAndSize(arg, &field_name, NULL) < 0) { +#else + char* field_name = PyUnicode_AsUTF8(arg); + if (!field_name) { +#endif + return NULL; + } + + google::protobuf::Message* message = self->message; + const google::protobuf::Descriptor* descriptor = message->GetDescriptor(); + bool is_in_oneof; + const google::protobuf::FieldDescriptor* field_descriptor = + FindFieldWithOneofs(message, field_name, &is_in_oneof); + if (field_descriptor == NULL) { + if (!is_in_oneof) { + PyErr_Format(PyExc_ValueError, "Unknown field %s.", field_name); + return NULL; + } else { + Py_RETURN_FALSE; + } + } + + if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + PyErr_Format(PyExc_ValueError, + "Protocol message has no singular \"%s\" field.", field_name); + return NULL; + } + + bool has_field = + message->GetReflection()->HasField(*message, field_descriptor); + if (!has_field && field_descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_ENUM) { + // We may have an invalid enum value stored in the UnknownFieldSet and need + // to check presence in there as well. + const google::protobuf::UnknownFieldSet& unknown_field_set = + message->GetReflection()->GetUnknownFields(*message); + for (int i = 0; i < unknown_field_set.field_count(); ++i) { + if (unknown_field_set.field(i).number() == field_descriptor->number()) { + Py_RETURN_TRUE; + } + } + Py_RETURN_FALSE; + } + return PyBool_FromLong(has_field ? 1 : 0); +} + +PyObject* ClearExtension(CMessage* self, PyObject* arg) { + if (self->extensions != NULL) { + return extension_dict::ClearExtension(self->extensions, arg); + } + PyErr_SetString(PyExc_TypeError, "Message is not extendable"); + return NULL; +} + +PyObject* HasExtension(CMessage* self, PyObject* arg) { + if (self->extensions != NULL) { + return extension_dict::HasExtension(self->extensions, arg); + } + PyErr_SetString(PyExc_TypeError, "Message is not extendable"); + return NULL; +} + +// --------------------------------------------------------------------- +// Releasing messages +// +// The Python API's ClearField() and Clear() methods behave +// differently than their C++ counterparts. While the C++ versions +// clears the children the Python versions detaches the children, +// without touching their content. This impedance mismatch causes +// some complexity in the implementation, which is captured in this +// section. +// +// When a CMessage field is cleared we need to: +// +// * Release the Message used as the backing store for the CMessage +// from its parent. +// +// * Change the owner field of the released CMessage and all of its +// children to point to the newly released Message. +// +// * Clear the weak references from the released CMessage to the +// parent. +// +// When a RepeatedCompositeContainer field is cleared we need to: +// +// * Release all the Message used as the backing store for the +// CMessages stored in the container. +// +// * Change the owner field of all the released CMessage and all of +// their children to point to the newly released Messages. +// +// * Clear the weak references from the released container to the +// parent. + +struct SetOwnerVisitor : public ChildVisitor { + // new_owner must outlive this object. + explicit SetOwnerVisitor(const shared_ptr& new_owner) + : new_owner_(new_owner) {} + + int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { + repeated_composite_container::SetOwner(container, new_owner_); + return 0; + } + + int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) { + repeated_scalar_container::SetOwner(container, new_owner_); + return 0; + } + + int VisitCMessage(CMessage* cmessage, + const google::protobuf::FieldDescriptor* field_descriptor) { + return SetOwner(cmessage, new_owner_); + } + + private: + const shared_ptr& new_owner_; +}; + +// Change the owner of this CMessage and all its children, recursively. +int SetOwner(CMessage* self, const shared_ptr& new_owner) { + self->owner = new_owner; + if (ForEachCompositeField(self, SetOwnerVisitor(new_owner)) == -1) + return -1; + return 0; +} + +// Releases the message specified by 'field' and returns the +// pointer. If the field does not exist a new message is created using +// 'descriptor'. The caller takes ownership of the returned pointer. +Message* ReleaseMessage(google::protobuf::Message* message, + const google::protobuf::Descriptor* descriptor, + const google::protobuf::FieldDescriptor* field_descriptor) { + Message* released_message = message->GetReflection()->ReleaseMessage( + message, field_descriptor, global_message_factory); + // ReleaseMessage will return NULL which differs from + // child_cmessage->message, if the field does not exist. In this case, + // the latter points to the default instance via a const_cast<>, so we + // have to reset it to a new mutable object since we are taking ownership. + if (released_message == NULL) { + const Message* prototype = global_message_factory->GetPrototype( + descriptor); + GOOGLE_DCHECK(prototype != NULL); + released_message = prototype->New(); + } + + return released_message; +} + +int ReleaseSubMessage(google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field_descriptor, + CMessage* child_cmessage) { + // Release the Message + shared_ptr released_message(ReleaseMessage( + message, child_cmessage->message->GetDescriptor(), field_descriptor)); + child_cmessage->message = released_message.get(); + child_cmessage->owner.swap(released_message); + child_cmessage->parent = NULL; + child_cmessage->parent_field = NULL; + child_cmessage->read_only = false; + return ForEachCompositeField(child_cmessage, + SetOwnerVisitor(child_cmessage->owner)); +} + +struct ReleaseChild : public ChildVisitor { + // message must outlive this object. + explicit ReleaseChild(google::protobuf::Message* parent_message) : + parent_message_(parent_message) {} + + int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { + return repeated_composite_container::Release( + reinterpret_cast(container)); + } + + int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) { + return repeated_scalar_container::Release( + reinterpret_cast(container)); + } + + int VisitCMessage(CMessage* cmessage, + const google::protobuf::FieldDescriptor* field_descriptor) { + return ReleaseSubMessage(parent_message_, field_descriptor, + reinterpret_cast(cmessage)); + } + + google::protobuf::Message* parent_message_; +}; + +int InternalReleaseFieldByDescriptor( + const google::protobuf::FieldDescriptor* field_descriptor, + PyObject* composite_field, + google::protobuf::Message* parent_message) { + return VisitCompositeField( + field_descriptor, + composite_field, + ReleaseChild(parent_message)); +} + +int InternalReleaseField(CMessage* self, PyObject* composite_field, + PyObject* name) { + PyObject* cdescriptor = GetDescriptor(self, name); + if (cdescriptor != NULL) { + const google::protobuf::FieldDescriptor* descriptor = + reinterpret_cast(cdescriptor)->descriptor; + return InternalReleaseFieldByDescriptor( + descriptor, composite_field, self->message); + } + + return 0; +} + +PyObject* ClearFieldByDescriptor( + CMessage* self, + const google::protobuf::FieldDescriptor* descriptor) { + if (!FIELD_BELONGS_TO_MESSAGE(descriptor, self->message)) { + PyErr_SetString(PyExc_KeyError, + "Field does not belong to message!"); + return NULL; + } + AssureWritable(self); + self->message->GetReflection()->ClearField(self->message, descriptor); + Py_RETURN_NONE; +} + +PyObject* ClearField(CMessage* self, PyObject* arg) { + char* field_name; + if (!PyString_Check(arg)) { + PyErr_SetString(PyExc_TypeError, "field name must be a string"); + return NULL; + } +#if PY_MAJOR_VERSION < 3 + if (PyString_AsStringAndSize(arg, &field_name, NULL) < 0) { + return NULL; + } +#else + field_name = PyUnicode_AsUTF8(arg); +#endif + AssureWritable(self); + google::protobuf::Message* message = self->message; + const google::protobuf::Descriptor* descriptor = message->GetDescriptor(); + ScopedPyObjectPtr arg_in_oneof; + bool is_in_oneof; + const google::protobuf::FieldDescriptor* field_descriptor = + FindFieldWithOneofs(message, field_name, &is_in_oneof); + if (field_descriptor == NULL) { + if (!is_in_oneof) { + PyErr_Format(PyExc_ValueError, + "Protocol message has no \"%s\" field.", field_name); + return NULL; + } else { + Py_RETURN_NONE; + } + } else if (is_in_oneof) { + arg_in_oneof.reset(PyString_FromString(field_descriptor->name().c_str())); + arg = arg_in_oneof.get(); + } + + PyObject* composite_field = PyDict_GetItem(self->composite_fields, + arg); + + // Only release the field if there's a possibility that there are + // references to it. + if (composite_field != NULL) { + if (InternalReleaseField(self, composite_field, arg) < 0) { + return NULL; + } + PyDict_DelItem(self->composite_fields, arg); + } + message->GetReflection()->ClearField(message, field_descriptor); + if (field_descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_ENUM) { + google::protobuf::UnknownFieldSet* unknown_field_set = + message->GetReflection()->MutableUnknownFields(message); + unknown_field_set->DeleteByNumber(field_descriptor->number()); + } + + Py_RETURN_NONE; +} + +PyObject* Clear(CMessage* self) { + AssureWritable(self); + if (ForEachCompositeField(self, ReleaseChild(self->message)) == -1) + return NULL; + + // The old ExtensionDict still aliases this CMessage, but all its + // fields have been released. + if (self->extensions != NULL) { + Py_CLEAR(self->extensions); + PyObject* py_extension_dict = PyObject_CallObject( + reinterpret_cast(&ExtensionDict_Type), NULL); + if (py_extension_dict == NULL) { + return NULL; + } + ExtensionDict* extension_dict = reinterpret_cast( + py_extension_dict); + extension_dict->parent = self; + extension_dict->message = self->message; + self->extensions = extension_dict; + } + PyDict_Clear(self->composite_fields); + self->message->Clear(); + Py_RETURN_NONE; +} + +// --------------------------------------------------------------------- + +static string GetMessageName(CMessage* self) { + if (self->parent_field != NULL) { + return self->parent_field->descriptor->full_name(); + } else { + return self->message->GetDescriptor()->full_name(); + } +} + +static PyObject* SerializeToString(CMessage* self, PyObject* args) { + if (!self->message->IsInitialized()) { + ScopedPyObjectPtr errors(FindInitializationErrors(self)); + if (errors == NULL) { + return NULL; + } + ScopedPyObjectPtr comma(PyString_FromString(",")); + if (comma == NULL) { + return NULL; + } + ScopedPyObjectPtr joined( + PyObject_CallMethod(comma.get(), "join", "O", errors.get())); + if (joined == NULL) { + return NULL; + } + PyErr_Format(EncodeError_class, "Message %s is missing required fields: %s", + GetMessageName(self).c_str(), PyString_AsString(joined.get())); + return NULL; + } + int size = self->message->ByteSize(); + if (size <= 0) { + return PyBytes_FromString(""); + } + PyObject* result = PyBytes_FromStringAndSize(NULL, size); + if (result == NULL) { + return NULL; + } + char* buffer = PyBytes_AS_STRING(result); + self->message->SerializeWithCachedSizesToArray( + reinterpret_cast(buffer)); + return result; +} + +static PyObject* SerializePartialToString(CMessage* self) { + string contents; + self->message->SerializePartialToString(&contents); + return PyBytes_FromStringAndSize(contents.c_str(), contents.size()); +} + +// Formats proto fields for ascii dumps using python formatting functions where +// appropriate. +class PythonFieldValuePrinter : public google::protobuf::TextFormat::FieldValuePrinter { + public: + PythonFieldValuePrinter() : float_holder_(PyFloat_FromDouble(0)) {} + + // Python has some differences from C++ when printing floating point numbers. + // + // 1) Trailing .0 is always printed. + // 2) Outputted is rounded to 12 digits. + // + // We override floating point printing with the C-API function for printing + // Python floats to ensure consistency. + string PrintFloat(float value) const { return PrintDouble(value); } + string PrintDouble(double value) const { + reinterpret_cast(float_holder_.get())->ob_fval = value; + ScopedPyObjectPtr s(PyObject_Str(float_holder_.get())); + if (s == NULL) return string(); +#if PY_MAJOR_VERSION < 3 + char *cstr = PyBytes_AS_STRING(static_cast(s)); +#else + char *cstr = PyUnicode_AsUTF8(s); +#endif + return string(cstr); + } + + private: + // Holder for a python float object which we use to allow us to use + // the Python API for printing doubles. We initialize once and then + // directly modify it for every float printed to save on allocations + // and refcounting. + ScopedPyObjectPtr float_holder_; +}; + +static PyObject* ToStr(CMessage* self) { + google::protobuf::TextFormat::Printer printer; + // Passes ownership + printer.SetDefaultFieldValuePrinter(new PythonFieldValuePrinter()); + printer.SetHideUnknownFields(true); + string output; + if (!printer.PrintToString(*self->message, &output)) { + PyErr_SetString(PyExc_ValueError, "Unable to convert message to str"); + return NULL; + } + return PyString_FromString(output.c_str()); +} + +PyObject* MergeFrom(CMessage* self, PyObject* arg) { + CMessage* other_message; + if (!PyObject_TypeCheck(reinterpret_cast(arg), &CMessage_Type)) { + PyErr_SetString(PyExc_TypeError, "Must be a message"); + return NULL; + } + + other_message = reinterpret_cast(arg); + if (other_message->message->GetDescriptor() != + self->message->GetDescriptor()) { + PyErr_Format(PyExc_TypeError, + "Tried to merge from a message with a different type. " + "to: %s, from: %s", + self->message->GetDescriptor()->full_name().c_str(), + other_message->message->GetDescriptor()->full_name().c_str()); + return NULL; + } + AssureWritable(self); + + // TODO(tibell): Message::MergeFrom might turn some child Messages + // into mutable messages, invalidating the message field in the + // corresponding CMessages. We should run a FixupMessageReferences + // pass here. + + self->message->MergeFrom(*other_message->message); + Py_RETURN_NONE; +} + +static PyObject* CopyFrom(CMessage* self, PyObject* arg) { + CMessage* other_message; + if (!PyObject_TypeCheck(reinterpret_cast(arg), &CMessage_Type)) { + PyErr_SetString(PyExc_TypeError, "Must be a message"); + return NULL; + } + + other_message = reinterpret_cast(arg); + + if (self == other_message) { + Py_RETURN_NONE; + } + + if (other_message->message->GetDescriptor() != + self->message->GetDescriptor()) { + PyErr_Format(PyExc_TypeError, + "Tried to copy from a message with a different type. " + "to: %s, from: %s", + self->message->GetDescriptor()->full_name().c_str(), + other_message->message->GetDescriptor()->full_name().c_str()); + return NULL; + } + + AssureWritable(self); + + // CopyFrom on the message will not clean up self->composite_fields, + // which can leave us in an inconsistent state, so clear it out here. + Clear(self); + + self->message->CopyFrom(*other_message->message); + + Py_RETURN_NONE; +} + +static PyObject* MergeFromString(CMessage* self, PyObject* arg) { + const void* data; + Py_ssize_t data_length; + if (PyObject_AsReadBuffer(arg, &data, &data_length) < 0) { + return NULL; + } + + AssureWritable(self); + google::protobuf::io::CodedInputStream input( + reinterpret_cast(data), data_length); + input.SetExtensionRegistry(GetDescriptorPool(), global_message_factory); + bool success = self->message->MergePartialFromCodedStream(&input); + if (success) { + return PyInt_FromLong(input.CurrentPosition()); + } else { + PyErr_Format(DecodeError_class, "Error parsing message"); + return NULL; + } +} + +static PyObject* ParseFromString(CMessage* self, PyObject* arg) { + if (Clear(self) == NULL) { + return NULL; + } + return MergeFromString(self, arg); +} + +static PyObject* ByteSize(CMessage* self, PyObject* args) { + return PyLong_FromLong(self->message->ByteSize()); +} + +static PyObject* RegisterExtension(PyObject* cls, + PyObject* extension_handle) { + ScopedPyObjectPtr message_descriptor(PyObject_GetAttr(cls, kDESCRIPTOR)); + if (message_descriptor == NULL) { + return NULL; + } + if (PyObject_SetAttrString(extension_handle, "containing_type", + message_descriptor) < 0) { + return NULL; + } + ScopedPyObjectPtr extensions_by_name( + PyObject_GetAttr(cls, k_extensions_by_name)); + if (extensions_by_name == NULL) { + PyErr_SetString(PyExc_TypeError, "no extensions_by_name on class"); + return NULL; + } + ScopedPyObjectPtr full_name(PyObject_GetAttr(extension_handle, kfull_name)); + if (full_name == NULL) { + return NULL; + } + if (PyDict_SetItem(extensions_by_name, full_name, extension_handle) < 0) { + return NULL; + } + + // Also store a mapping from extension number to implementing class. + ScopedPyObjectPtr extensions_by_number( + PyObject_GetAttr(cls, k_extensions_by_number)); + if (extensions_by_number == NULL) { + PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class"); + return NULL; + } + ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number")); + if (number == NULL) { + return NULL; + } + if (PyDict_SetItem(extensions_by_number, number, extension_handle) < 0) { + return NULL; + } + + CFieldDescriptor* cdescriptor = + extension_dict::InternalGetCDescriptorFromExtension(extension_handle); + ScopedPyObjectPtr py_cdescriptor(reinterpret_cast(cdescriptor)); + if (cdescriptor == NULL) { + return NULL; + } + Py_INCREF(extension_handle); + cdescriptor->descriptor_field = extension_handle; + const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor; + // Check if it's a message set + if (descriptor->is_extension() && + descriptor->containing_type()->options().message_set_wire_format() && + descriptor->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE && + descriptor->message_type() == descriptor->extension_scope() && + descriptor->label() == google::protobuf::FieldDescriptor::LABEL_OPTIONAL) { + ScopedPyObjectPtr message_name(PyString_FromStringAndSize( + descriptor->message_type()->full_name().c_str(), + descriptor->message_type()->full_name().size())); + if (message_name == NULL) { + return NULL; + } + PyDict_SetItem(extensions_by_name, message_name, extension_handle); + } + + Py_RETURN_NONE; +} + +static PyObject* SetInParent(CMessage* self, PyObject* args) { + AssureWritable(self); + Py_RETURN_NONE; +} + +static PyObject* WhichOneof(CMessage* self, PyObject* arg) { + char* oneof_name; + if (!PyString_Check(arg)) { + PyErr_SetString(PyExc_TypeError, "field name must be a string"); + return NULL; + } + oneof_name = PyString_AsString(arg); + if (oneof_name == NULL) { + return NULL; + } + const google::protobuf::OneofDescriptor* oneof_desc = + self->message->GetDescriptor()->FindOneofByName(oneof_name); + if (oneof_desc == NULL) { + PyErr_Format(PyExc_ValueError, + "Protocol message has no oneof \"%s\" field.", oneof_name); + return NULL; + } + const google::protobuf::FieldDescriptor* field_in_oneof = + self->message->GetReflection()->GetOneofFieldDescriptor( + *self->message, oneof_desc); + if (field_in_oneof == NULL) { + Py_RETURN_NONE; + } else { + return PyString_FromString(field_in_oneof->name().c_str()); + } +} + +static PyObject* ListFields(CMessage* self) { + vector fields; + self->message->GetReflection()->ListFields(*self->message, &fields); + + PyObject* descriptor = PyDict_GetItem(Py_TYPE(self)->tp_dict, kDESCRIPTOR); + if (descriptor == NULL) { + return NULL; + } + ScopedPyObjectPtr fields_by_name( + PyObject_GetAttr(descriptor, kfields_by_name)); + if (fields_by_name == NULL) { + return NULL; + } + ScopedPyObjectPtr extensions_by_name(PyObject_GetAttr( + reinterpret_cast(Py_TYPE(self)), k_extensions_by_name)); + if (extensions_by_name == NULL) { + PyErr_SetString(PyExc_ValueError, "no extensionsbyname"); + return NULL; + } + // Normally, the list will be exactly the size of the fields. + PyObject* all_fields = PyList_New(fields.size()); + if (all_fields == NULL) { + return NULL; + } + + // When there are unknown extensions, the py list will *not* contain + // the field information. Thus the actual size of the py list will be + // smaller than the size of fields. Set the actual size at the end. + Py_ssize_t actual_size = 0; + for (Py_ssize_t i = 0; i < fields.size(); ++i) { + ScopedPyObjectPtr t(PyTuple_New(2)); + if (t == NULL) { + Py_DECREF(all_fields); + return NULL; + } + + if (fields[i]->is_extension()) { + const string& field_name = fields[i]->full_name(); + PyObject* extension_field = PyDict_GetItemString(extensions_by_name, + field_name.c_str()); + if (extension_field == NULL) { + // If we couldn't fetch extension_field, it means the module that + // defines this extension has not been explicitly imported in Python + // code, and the extension hasn't been registered. There's nothing much + // we can do about this, so just skip it in the output to match the + // behavior of the python implementation. + continue; + } + PyObject* extensions = reinterpret_cast(self->extensions); + if (extensions == NULL) { + Py_DECREF(all_fields); + return NULL; + } + // 'extension' reference later stolen by PyTuple_SET_ITEM. + PyObject* extension = PyObject_GetItem(extensions, extension_field); + if (extension == NULL) { + Py_DECREF(all_fields); + return NULL; + } + Py_INCREF(extension_field); + PyTuple_SET_ITEM(t.get(), 0, extension_field); + // Steals reference to 'extension' + PyTuple_SET_ITEM(t.get(), 1, extension); + } else { + const string& field_name = fields[i]->name(); + ScopedPyObjectPtr py_field_name(PyString_FromStringAndSize( + field_name.c_str(), field_name.length())); + if (py_field_name == NULL) { + PyErr_SetString(PyExc_ValueError, "bad string"); + Py_DECREF(all_fields); + return NULL; + } + PyObject* field_descriptor = + PyDict_GetItem(fields_by_name, py_field_name); + if (field_descriptor == NULL) { + Py_DECREF(all_fields); + return NULL; + } + + PyObject* field_value = GetAttr(self, py_field_name); + if (field_value == NULL) { + PyErr_SetObject(PyExc_ValueError, py_field_name); + Py_DECREF(all_fields); + return NULL; + } + Py_INCREF(field_descriptor); + PyTuple_SET_ITEM(t.get(), 0, field_descriptor); + PyTuple_SET_ITEM(t.get(), 1, field_value); + } + PyList_SET_ITEM(all_fields, actual_size, t.release()); + ++actual_size; + } + Py_SIZE(all_fields) = actual_size; + return all_fields; +} + +PyObject* FindInitializationErrors(CMessage* self) { + google::protobuf::Message* message = self->message; + vector errors; + message->FindInitializationErrors(&errors); + + PyObject* error_list = PyList_New(errors.size()); + if (error_list == NULL) { + return NULL; + } + for (Py_ssize_t i = 0; i < errors.size(); ++i) { + const string& error = errors[i]; + PyObject* error_string = PyString_FromStringAndSize( + error.c_str(), error.length()); + if (error_string == NULL) { + Py_DECREF(error_list); + return NULL; + } + PyList_SET_ITEM(error_list, i, error_string); + } + return error_list; +} + +static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) { + if (!PyObject_TypeCheck(other, &CMessage_Type)) { + if (opid == Py_EQ) { + Py_RETURN_FALSE; + } else if (opid == Py_NE) { + Py_RETURN_TRUE; + } + } + if (opid == Py_EQ || opid == Py_NE) { + ScopedPyObjectPtr self_fields(ListFields(self)); + ScopedPyObjectPtr other_fields(ListFields( + reinterpret_cast(other))); + return PyObject_RichCompare(self_fields, other_fields, opid); + } else { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } +} + +PyObject* InternalGetScalar( + CMessage* self, + const google::protobuf::FieldDescriptor* field_descriptor) { + google::protobuf::Message* message = self->message; + const google::protobuf::Reflection* reflection = message->GetReflection(); + + if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { + PyErr_SetString( + PyExc_KeyError, "Field does not belong to message!"); + return NULL; + } + + PyObject* result = NULL; + switch (field_descriptor->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + int32 value = reflection->GetInt32(*message, field_descriptor); + result = PyInt_FromLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + int64 value = reflection->GetInt64(*message, field_descriptor); + result = PyLong_FromLongLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + uint32 value = reflection->GetUInt32(*message, field_descriptor); + result = PyInt_FromSize_t(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + uint64 value = reflection->GetUInt64(*message, field_descriptor); + result = PyLong_FromUnsignedLongLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + float value = reflection->GetFloat(*message, field_descriptor); + result = PyFloat_FromDouble(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + double value = reflection->GetDouble(*message, field_descriptor); + result = PyFloat_FromDouble(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + bool value = reflection->GetBool(*message, field_descriptor); + result = PyBool_FromLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + string value = reflection->GetString(*message, field_descriptor); + result = ToStringObject(field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + if (!message->GetReflection()->HasField(*message, field_descriptor)) { + // Look for the value in the unknown fields. + google::protobuf::UnknownFieldSet* unknown_field_set = + message->GetReflection()->MutableUnknownFields(message); + for (int i = 0; i < unknown_field_set->field_count(); ++i) { + if (unknown_field_set->field(i).number() == + field_descriptor->number()) { + result = PyInt_FromLong(unknown_field_set->field(i).varint()); + break; + } + } + } + + if (result == NULL) { + const google::protobuf::EnumValueDescriptor* enum_value = + message->GetReflection()->GetEnum(*message, field_descriptor); + result = PyInt_FromLong(enum_value->number()); + } + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Getting a value from a field of unknown type %d", + field_descriptor->cpp_type()); + } + + return result; +} + +PyObject* InternalGetSubMessage(CMessage* self, + CFieldDescriptor* cfield_descriptor) { + PyObject* field = cfield_descriptor->descriptor_field; + ScopedPyObjectPtr message_type(PyObject_GetAttr(field, kmessage_type)); + if (message_type == NULL) { + return NULL; + } + ScopedPyObjectPtr concrete_class( + PyObject_GetAttr(message_type, k_concrete_class)); + if (concrete_class == NULL) { + return NULL; + } + PyObject* py_cmsg = cmessage::NewEmpty(concrete_class); + if (py_cmsg == NULL) { + return NULL; + } + if (!PyObject_TypeCheck(py_cmsg, &CMessage_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a CMessage!"); + } + CMessage* cmsg = reinterpret_cast(py_cmsg); + + const google::protobuf::FieldDescriptor* field_descriptor = + cfield_descriptor->descriptor; + const google::protobuf::Reflection* reflection = self->message->GetReflection(); + const google::protobuf::Message& sub_message = reflection->GetMessage( + *self->message, field_descriptor, global_message_factory); + cmsg->owner = self->owner; + cmsg->parent = self; + cmsg->parent_field = cfield_descriptor; + cmsg->read_only = !reflection->HasField(*self->message, field_descriptor); + cmsg->message = const_cast(&sub_message); + + if (InitAttributes(cmsg, NULL, NULL) < 0) { + Py_DECREF(py_cmsg); + return NULL; + } + return py_cmsg; +} + +int InternalSetScalar( + CMessage* self, + const google::protobuf::FieldDescriptor* field_descriptor, + PyObject* arg) { + google::protobuf::Message* message = self->message; + const google::protobuf::Reflection* reflection = message->GetReflection(); + + if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { + PyErr_SetString( + PyExc_KeyError, "Field does not belong to message!"); + return -1; + } + + if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) { + return -1; + } + + switch (field_descriptor->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(arg, value, -1); + reflection->SetInt32(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(arg, value, -1); + reflection->SetInt64(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(arg, value, -1); + reflection->SetUInt32(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(arg, value, -1); + reflection->SetUInt64(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + GOOGLE_CHECK_GET_FLOAT(arg, value, -1); + reflection->SetFloat(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + GOOGLE_CHECK_GET_DOUBLE(arg, value, -1); + reflection->SetDouble(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(arg, value, -1); + reflection->SetBool(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + if (!CheckAndSetString( + arg, message, field_descriptor, reflection, false, -1)) { + return -1; + } + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + GOOGLE_CHECK_GET_INT32(arg, value, -1); + const google::protobuf::EnumDescriptor* enum_descriptor = + field_descriptor->enum_type(); + const google::protobuf::EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + reflection->SetEnum(message, field_descriptor, enum_value); + } else { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value); + return -1; + } + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Setting value to a field of unknown type %d", + field_descriptor->cpp_type()); + return -1; + } + + return 0; +} + +PyObject* FromString(PyTypeObject* cls, PyObject* serialized) { + PyObject* py_cmsg = PyObject_CallObject( + reinterpret_cast(cls), NULL); + if (py_cmsg == NULL) { + return NULL; + } + CMessage* cmsg = reinterpret_cast(py_cmsg); + + ScopedPyObjectPtr py_length(MergeFromString(cmsg, serialized)); + if (py_length == NULL) { + Py_DECREF(py_cmsg); + return NULL; + } + + if (InitAttributes(cmsg, NULL, NULL) < 0) { + Py_DECREF(py_cmsg); + return NULL; + } + return py_cmsg; +} + +static PyObject* AddDescriptors(PyTypeObject* cls, + PyObject* descriptor) { + if (PyObject_SetAttr(reinterpret_cast(cls), + k_extensions_by_name, PyDict_New()) < 0) { + return NULL; + } + if (PyObject_SetAttr(reinterpret_cast(cls), + k_extensions_by_number, PyDict_New()) < 0) { + return NULL; + } + + ScopedPyObjectPtr field_descriptors(PyDict_New()); + + ScopedPyObjectPtr fields(PyObject_GetAttrString(descriptor, "fields")); + if (fields == NULL) { + return NULL; + } + + ScopedPyObjectPtr _NUMBER_string(PyString_FromString("_FIELD_NUMBER")); + if (_NUMBER_string == NULL) { + return NULL; + } + + const Py_ssize_t fields_size = PyList_GET_SIZE(fields.get()); + for (int i = 0; i < fields_size; ++i) { + PyObject* field = PyList_GET_ITEM(fields.get(), i); + ScopedPyObjectPtr field_name(PyObject_GetAttr(field, kname)); + ScopedPyObjectPtr full_field_name(PyObject_GetAttr(field, kfull_name)); + if (field_name == NULL || full_field_name == NULL) { + PyErr_SetString(PyExc_TypeError, "Name is null"); + return NULL; + } + + PyObject* field_descriptor = + cdescriptor_pool::FindFieldByName(descriptor_pool, full_field_name); + if (field_descriptor == NULL) { + PyErr_SetString(PyExc_TypeError, "Couldn't find field"); + return NULL; + } + Py_INCREF(field); + CFieldDescriptor* cfield_descriptor = reinterpret_cast( + field_descriptor); + cfield_descriptor->descriptor_field = field; + if (PyDict_SetItem(field_descriptors, field_name, field_descriptor) < 0) { + return NULL; + } + + // The FieldDescriptor's name field might either be of type bytes or + // of type unicode, depending on whether the FieldDescriptor was + // parsed from a serialized message or read from the + // _pb2.py module. + ScopedPyObjectPtr field_name_upcased( + PyObject_CallMethod(field_name, "upper", NULL)); + if (field_name_upcased == NULL) { + return NULL; + } + + ScopedPyObjectPtr field_number_name(PyObject_CallMethod( + field_name_upcased, "__add__", "(O)", _NUMBER_string.get())); + if (field_number_name == NULL) { + return NULL; + } + + ScopedPyObjectPtr number(PyInt_FromLong( + cfield_descriptor->descriptor->number())); + if (number == NULL) { + return NULL; + } + if (PyObject_SetAttr(reinterpret_cast(cls), + field_number_name, number) == -1) { + return NULL; + } + } + + PyDict_SetItem(cls->tp_dict, k__descriptors, field_descriptors); + + // Enum Values + ScopedPyObjectPtr enum_types(PyObject_GetAttrString(descriptor, + "enum_types")); + if (enum_types == NULL) { + return NULL; + } + ScopedPyObjectPtr type_iter(PyObject_GetIter(enum_types)); + if (type_iter == NULL) { + return NULL; + } + ScopedPyObjectPtr enum_type; + while ((enum_type.reset(PyIter_Next(type_iter))) != NULL) { + ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs( + EnumTypeWrapper_class, enum_type.get(), NULL)); + if (wrapped == NULL) { + return NULL; + } + ScopedPyObjectPtr enum_name(PyObject_GetAttr(enum_type, kname)); + if (enum_name == NULL) { + return NULL; + } + if (PyObject_SetAttr(reinterpret_cast(cls), + enum_name, wrapped) == -1) { + return NULL; + } + + ScopedPyObjectPtr enum_values(PyObject_GetAttrString(enum_type, "values")); + if (enum_values == NULL) { + return NULL; + } + ScopedPyObjectPtr values_iter(PyObject_GetIter(enum_values)); + if (values_iter == NULL) { + return NULL; + } + ScopedPyObjectPtr enum_value; + while ((enum_value.reset(PyIter_Next(values_iter))) != NULL) { + ScopedPyObjectPtr value_name(PyObject_GetAttr(enum_value, kname)); + if (value_name == NULL) { + return NULL; + } + ScopedPyObjectPtr value_number(PyObject_GetAttrString(enum_value, + "number")); + if (value_number == NULL) { + return NULL; + } + if (PyObject_SetAttr(reinterpret_cast(cls), + value_name, value_number) == -1) { + return NULL; + } + } + if (PyErr_Occurred()) { // If PyIter_Next failed + return NULL; + } + } + if (PyErr_Occurred()) { // If PyIter_Next failed + return NULL; + } + + ScopedPyObjectPtr extension_dict( + PyObject_GetAttr(descriptor, kextensions_by_name)); + if (extension_dict == NULL || !PyDict_Check(extension_dict)) { + PyErr_SetString(PyExc_TypeError, "extensions_by_name not a dict"); + return NULL; + } + Py_ssize_t pos = 0; + PyObject* extension_name; + PyObject* extension_field; + + while (PyDict_Next(extension_dict, &pos, &extension_name, &extension_field)) { + if (PyObject_SetAttr(reinterpret_cast(cls), + extension_name, extension_field) == -1) { + return NULL; + } + ScopedPyObjectPtr py_cfield_descriptor( + PyObject_GetAttrString(extension_field, "_cdescriptor")); + if (py_cfield_descriptor == NULL) { + return NULL; + } + CFieldDescriptor* cfield_descriptor = + reinterpret_cast(py_cfield_descriptor.get()); + Py_INCREF(extension_field); + cfield_descriptor->descriptor_field = extension_field; + + ScopedPyObjectPtr field_name_upcased( + PyObject_CallMethod(extension_name, "upper", NULL)); + if (field_name_upcased == NULL) { + return NULL; + } + ScopedPyObjectPtr field_number_name(PyObject_CallMethod( + field_name_upcased, "__add__", "(O)", _NUMBER_string.get())); + if (field_number_name == NULL) { + return NULL; + } + ScopedPyObjectPtr number(PyInt_FromLong( + cfield_descriptor->descriptor->number())); + if (number == NULL) { + return NULL; + } + if (PyObject_SetAttr(reinterpret_cast(cls), + field_number_name, PyInt_FromLong( + cfield_descriptor->descriptor->number())) == -1) { + return NULL; + } + } + + Py_RETURN_NONE; +} + +PyObject* DeepCopy(CMessage* self, PyObject* arg) { + PyObject* clone = PyObject_CallObject( + reinterpret_cast(Py_TYPE(self)), NULL); + if (clone == NULL) { + return NULL; + } + if (!PyObject_TypeCheck(clone, &CMessage_Type)) { + Py_DECREF(clone); + return NULL; + } + if (InitAttributes(reinterpret_cast(clone), NULL, NULL) < 0) { + Py_DECREF(clone); + return NULL; + } + if (MergeFrom(reinterpret_cast(clone), + reinterpret_cast(self)) == NULL) { + Py_DECREF(clone); + return NULL; + } + return clone; +} + +PyObject* ToUnicode(CMessage* self) { + // Lazy import to prevent circular dependencies + ScopedPyObjectPtr text_format( + PyImport_ImportModule("google.protobuf.text_format")); + if (text_format == NULL) { + return NULL; + } + ScopedPyObjectPtr method_name(PyString_FromString("MessageToString")); + if (method_name == NULL) { + return NULL; + } + Py_INCREF(Py_True); + ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs(text_format, method_name, + self, Py_True, NULL)); + Py_DECREF(Py_True); + if (encoded == NULL) { + return NULL; + } +#if PY_MAJOR_VERSION < 3 + PyObject* decoded = PyString_AsDecodedObject(encoded, "utf-8", NULL); +#else + PyObject* decoded = PyUnicode_FromEncodedObject(encoded, "utf-8", NULL); +#endif + if (decoded == NULL) { + return NULL; + } + return decoded; +} + +PyObject* Reduce(CMessage* self) { + ScopedPyObjectPtr constructor(reinterpret_cast(Py_TYPE(self))); + constructor.inc(); + ScopedPyObjectPtr args(PyTuple_New(0)); + if (args == NULL) { + return NULL; + } + ScopedPyObjectPtr state(PyDict_New()); + if (state == NULL) { + return NULL; + } + ScopedPyObjectPtr serialized(SerializePartialToString(self)); + if (serialized == NULL) { + return NULL; + } + if (PyDict_SetItemString(state, "serialized", serialized) < 0) { + return NULL; + } + return Py_BuildValue("OOO", constructor.get(), args.get(), state.get()); +} + +PyObject* SetState(CMessage* self, PyObject* state) { + if (!PyDict_Check(state)) { + PyErr_SetString(PyExc_TypeError, "state not a dict"); + return NULL; + } + PyObject* serialized = PyDict_GetItemString(state, "serialized"); + if (serialized == NULL) { + return NULL; + } + if (ParseFromString(self, serialized) == NULL) { + return NULL; + } + Py_RETURN_NONE; +} + +// CMessage static methods: +PyObject* _GetFieldDescriptor(PyObject* unused, PyObject* arg) { + return cdescriptor_pool::FindFieldByName(descriptor_pool, arg); +} + +PyObject* _GetExtensionDescriptor(PyObject* unused, PyObject* arg) { + return cdescriptor_pool::FindExtensionByName(descriptor_pool, arg); +} + +static PyMemberDef Members[] = { + {"Extensions", T_OBJECT_EX, offsetof(CMessage, extensions), 0, + "Extension dict"}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, + "Makes a deep copy of the class." }, + { "__reduce__", (PyCFunction)Reduce, METH_NOARGS, + "Outputs picklable representation of the message." }, + { "__setstate__", (PyCFunction)SetState, METH_O, + "Inputs picklable representation of the message." }, + { "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS, + "Outputs a unicode representation of the message." }, + { "AddDescriptors", (PyCFunction)AddDescriptors, METH_O | METH_CLASS, + "Adds field descriptors to the class" }, + { "ByteSize", (PyCFunction)ByteSize, METH_NOARGS, + "Returns the size of the message in bytes." }, + { "Clear", (PyCFunction)Clear, METH_NOARGS, + "Clears the message." }, + { "ClearExtension", (PyCFunction)ClearExtension, METH_O, + "Clears a message field." }, + { "ClearField", (PyCFunction)ClearField, METH_O, + "Clears a message field." }, + { "CopyFrom", (PyCFunction)CopyFrom, METH_O, + "Copies a protocol message into the current message." }, + { "FindInitializationErrors", (PyCFunction)FindInitializationErrors, + METH_NOARGS, + "Finds unset required fields." }, + { "FromString", (PyCFunction)FromString, METH_O | METH_CLASS, + "Creates new method instance from given serialized data." }, + { "HasExtension", (PyCFunction)HasExtension, METH_O, + "Checks if a message field is set." }, + { "HasField", (PyCFunction)HasField, METH_O, + "Checks if a message field is set." }, + { "IsInitialized", (PyCFunction)IsInitialized, METH_VARARGS, + "Checks if all required fields of a protocol message are set." }, + { "ListFields", (PyCFunction)ListFields, METH_NOARGS, + "Lists all set fields of a message." }, + { "MergeFrom", (PyCFunction)MergeFrom, METH_O, + "Merges a protocol message into the current message." }, + { "MergeFromString", (PyCFunction)MergeFromString, METH_O, + "Merges a serialized message into the current message." }, + { "ParseFromString", (PyCFunction)ParseFromString, METH_O, + "Parses a serialized message into the current message." }, + { "RegisterExtension", (PyCFunction)RegisterExtension, METH_O | METH_CLASS, + "Registers an extension with the current message." }, + { "SerializePartialToString", (PyCFunction)SerializePartialToString, + METH_NOARGS, + "Serializes the message to a string, even if it isn't initialized." }, + { "SerializeToString", (PyCFunction)SerializeToString, METH_NOARGS, + "Serializes the message to a string, only for initialized messages." }, + { "SetInParent", (PyCFunction)SetInParent, METH_NOARGS, + "Sets the has bit of the given field in its parent message." }, + { "WhichOneof", (PyCFunction)WhichOneof, METH_O, + "Returns the name of the field set inside a oneof, " + "or None if no field is set." }, + + // Static Methods. + { "_BuildFile", (PyCFunction)Python_BuildFile, METH_O | METH_STATIC, + "Registers a new protocol buffer file in the global C++ descriptor pool." }, + { "_GetFieldDescriptor", (PyCFunction)_GetFieldDescriptor, + METH_O | METH_STATIC, "Finds a field descriptor in the message pool." }, + { "_GetExtensionDescriptor", (PyCFunction)_GetExtensionDescriptor, + METH_O | METH_STATIC, + "Finds a extension descriptor in the message pool." }, + { NULL, NULL} +}; + +PyObject* GetAttr(CMessage* self, PyObject* name) { + PyObject* value = PyDict_GetItem(self->composite_fields, name); + if (value != NULL) { + Py_INCREF(value); + return value; + } + + PyObject* descriptor = GetDescriptor(self, name); + if (descriptor != NULL) { + CFieldDescriptor* cdescriptor = + reinterpret_cast(descriptor); + const google::protobuf::FieldDescriptor* field_descriptor = cdescriptor->descriptor; + if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + if (field_descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + PyObject* py_container = PyObject_CallObject( + reinterpret_cast(&RepeatedCompositeContainer_Type), + NULL); + if (py_container == NULL) { + return NULL; + } + RepeatedCompositeContainer* container = + reinterpret_cast(py_container); + PyObject* field = cdescriptor->descriptor_field; + PyObject* message_type = PyObject_GetAttr(field, kmessage_type); + if (message_type == NULL) { + return NULL; + } + PyObject* concrete_class = + PyObject_GetAttr(message_type, k_concrete_class); + if (concrete_class == NULL) { + return NULL; + } + container->parent = self; + container->parent_field = cdescriptor; + container->message = self->message; + container->owner = self->owner; + container->subclass_init = concrete_class; + Py_DECREF(message_type); + if (PyDict_SetItem(self->composite_fields, name, py_container) < 0) { + Py_DECREF(py_container); + return NULL; + } + return py_container; + } else { + ScopedPyObjectPtr init_args(PyTuple_Pack(2, self, cdescriptor)); + PyObject* py_container = PyObject_CallObject( + reinterpret_cast(&RepeatedScalarContainer_Type), + init_args); + if (py_container == NULL) { + return NULL; + } + if (PyDict_SetItem(self->composite_fields, name, py_container) < 0) { + Py_DECREF(py_container); + return NULL; + } + return py_container; + } + } else { + if (field_descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + PyObject* sub_message = InternalGetSubMessage(self, cdescriptor); + if (PyDict_SetItem(self->composite_fields, name, sub_message) < 0) { + Py_DECREF(sub_message); + return NULL; + } + return sub_message; + } else { + return InternalGetScalar(self, field_descriptor); + } + } + } + + return CMessage_Type.tp_base->tp_getattro(reinterpret_cast(self), + name); +} + +int SetAttr(CMessage* self, PyObject* name, PyObject* value) { + if (PyDict_Contains(self->composite_fields, name)) { + PyErr_SetString(PyExc_TypeError, "Can't set composite field"); + return -1; + } + + PyObject* descriptor = GetDescriptor(self, name); + if (descriptor != NULL) { + AssureWritable(self); + CFieldDescriptor* cdescriptor = + reinterpret_cast(descriptor); + const google::protobuf::FieldDescriptor* field_descriptor = cdescriptor->descriptor; + if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + PyErr_Format(PyExc_AttributeError, "Assignment not allowed to repeated " + "field \"%s\" in protocol message object.", + field_descriptor->name().c_str()); + return -1; + } else { + if (field_descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + PyErr_Format(PyExc_AttributeError, "Assignment not allowed to " + "field \"%s\" in protocol message object.", + field_descriptor->name().c_str()); + return -1; + } else { + return InternalSetScalar(self, field_descriptor, value); + } + } + } + + PyErr_Format(PyExc_AttributeError, "Assignment not allowed"); + return -1; +} + +} // namespace cmessage + +PyTypeObject CMessage_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "google.protobuf.internal." + "cpp._message.CMessage", // tp_name + sizeof(CMessage), // tp_basicsize + 0, // tp_itemsize + (destructor)cmessage::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + (reprfunc)cmessage::ToStr, // tp_str + (getattrofunc)cmessage::GetAttr, // tp_getattro + (setattrofunc)cmessage::SetAttr, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags + "A ProtocolMessage", // tp_doc + 0, // tp_traverse + 0, // tp_clear + (richcmpfunc)cmessage::RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + cmessage::Methods, // tp_methods + cmessage::Members, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + (initproc)cmessage::Init, // tp_init + 0, // tp_alloc + cmessage::New, // tp_new +}; + +// --- Exposing the C proto living inside Python proto to C code: + +const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg); +Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg); + +static const google::protobuf::Message* GetCProtoInsidePyProtoImpl(PyObject* msg) { + if (!PyObject_TypeCheck(msg, &CMessage_Type)) { + return NULL; + } + CMessage* cmsg = reinterpret_cast(msg); + return cmsg->message; +} + +static google::protobuf::Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) { + if (!PyObject_TypeCheck(msg, &CMessage_Type)) { + return NULL; + } + CMessage* cmsg = reinterpret_cast(msg); + if (PyDict_Size(cmsg->composite_fields) != 0 || + (cmsg->extensions != NULL && + PyDict_Size(cmsg->extensions->values) != 0)) { + // There is currently no way of accurately syncing arbitrary changes to + // the underlying C++ message back to the CMessage (e.g. removed repeated + // composite containers). We only allow direct mutation of the underlying + // C++ message if there is no child data in the CMessage. + return NULL; + } + cmessage::AssureWritable(cmsg); + return cmsg->message; +} + +static const char module_docstring[] = +"python-proto2 is a module that can be used to enhance proto2 Python API\n" +"performance.\n" +"\n" +"It provides access to the protocol buffers C++ reflection API that\n" +"implements the basic protocol buffer functions."; + +void InitGlobals() { + // TODO(gps): Check all return values in this function for NULL and propagate + // the error (MemoryError) on up to result in an import failure. These should + // also be freed and reset to NULL during finalization. + kPythonZero = PyInt_FromLong(0); + kint32min_py = PyInt_FromLong(kint32min); + kint32max_py = PyInt_FromLong(kint32max); + kuint32max_py = PyLong_FromLongLong(kuint32max); + kint64min_py = PyLong_FromLongLong(kint64min); + kint64max_py = PyLong_FromLongLong(kint64max); + kuint64max_py = PyLong_FromUnsignedLongLong(kuint64max); + + kDESCRIPTOR = PyString_FromString("DESCRIPTOR"); + k__descriptors = PyString_FromString("__descriptors"); + kfull_name = PyString_FromString("full_name"); + kis_extendable = PyString_FromString("is_extendable"); + kextensions_by_name = PyString_FromString("extensions_by_name"); + k_extensions_by_name = PyString_FromString("_extensions_by_name"); + k_extensions_by_number = PyString_FromString("_extensions_by_number"); + k_concrete_class = PyString_FromString("_concrete_class"); + kmessage_type = PyString_FromString("message_type"); + kname = PyString_FromString("name"); + kfields_by_name = PyString_FromString("fields_by_name"); + + global_message_factory = new DynamicMessageFactory(GetDescriptorPool()); + global_message_factory->SetDelegateToGeneratedFactory(true); + + descriptor_pool = reinterpret_cast( + Python_NewCDescriptorPool(NULL, NULL)); +} + +bool InitProto2MessageModule(PyObject *m) { + InitGlobals(); + + google::protobuf::python::CMessage_Type.tp_hash = PyObject_HashNotImplemented; + if (PyType_Ready(&google::protobuf::python::CMessage_Type) < 0) { + return false; + } + + // All three of these are actually set elsewhere, directly onto the child + // protocol buffer message class, but set them here as well to document that + // subclasses need to set these. + PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict, kDESCRIPTOR, Py_None); + PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict, + k_extensions_by_name, Py_None); + PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict, + k_extensions_by_number, Py_None); + + PyModule_AddObject(m, "Message", reinterpret_cast( + &google::protobuf::python::CMessage_Type)); + + google::protobuf::python::RepeatedScalarContainer_Type.tp_new = PyType_GenericNew; + google::protobuf::python::RepeatedScalarContainer_Type.tp_hash = + PyObject_HashNotImplemented; + if (PyType_Ready(&google::protobuf::python::RepeatedScalarContainer_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "RepeatedScalarContainer", + reinterpret_cast( + &google::protobuf::python::RepeatedScalarContainer_Type)); + + google::protobuf::python::RepeatedCompositeContainer_Type.tp_new = PyType_GenericNew; + google::protobuf::python::RepeatedCompositeContainer_Type.tp_hash = + PyObject_HashNotImplemented; + if (PyType_Ready(&google::protobuf::python::RepeatedCompositeContainer_Type) < 0) { + return false; + } + + PyModule_AddObject( + m, "RepeatedCompositeContainer", + reinterpret_cast( + &google::protobuf::python::RepeatedCompositeContainer_Type)); + + google::protobuf::python::ExtensionDict_Type.tp_new = PyType_GenericNew; + google::protobuf::python::ExtensionDict_Type.tp_hash = PyObject_HashNotImplemented; + if (PyType_Ready(&google::protobuf::python::ExtensionDict_Type) < 0) { + return false; + } + + PyModule_AddObject( + m, "ExtensionDict", + reinterpret_cast(&google::protobuf::python::ExtensionDict_Type)); + + if (!google::protobuf::python::InitDescriptor()) { + return false; + } + + PyObject* enum_type_wrapper = PyImport_ImportModule( + "google.protobuf.internal.enum_type_wrapper"); + if (enum_type_wrapper == NULL) { + return false; + } + google::protobuf::python::EnumTypeWrapper_class = + PyObject_GetAttrString(enum_type_wrapper, "EnumTypeWrapper"); + Py_DECREF(enum_type_wrapper); + + PyObject* message_module = PyImport_ImportModule( + "google.protobuf.message"); + if (message_module == NULL) { + return false; + } + google::protobuf::python::EncodeError_class = PyObject_GetAttrString(message_module, + "EncodeError"); + google::protobuf::python::DecodeError_class = PyObject_GetAttrString(message_module, + "DecodeError"); + Py_DECREF(message_module); + + PyObject* pickle_module = PyImport_ImportModule("pickle"); + if (pickle_module == NULL) { + return false; + } + google::protobuf::python::PickleError_class = PyObject_GetAttrString(pickle_module, + "PickleError"); + Py_DECREF(pickle_module); + + // Override {Get,Mutable}CProtoInsidePyProto. + google::protobuf::python::GetCProtoInsidePyProtoPtr = + google::protobuf::python::GetCProtoInsidePyProtoImpl; + google::protobuf::python::MutableCProtoInsidePyProtoPtr = + google::protobuf::python::MutableCProtoInsidePyProtoImpl; + + return true; +} + +} // namespace python +} // namespace protobuf + + +#if PY_MAJOR_VERSION >= 3 +static struct PyModuleDef _module = { + PyModuleDef_HEAD_INIT, + "_message", + google::protobuf::python::module_docstring, + -1, + NULL, + NULL, + NULL, + NULL, + NULL +}; +#define INITFUNC PyInit__message +#define INITFUNC_ERRORVAL NULL +#else // Python 2 +#define INITFUNC init_message +#define INITFUNC_ERRORVAL +#endif + +extern "C" { + PyMODINIT_FUNC INITFUNC(void) { + PyObject* m; +#if PY_MAJOR_VERSION >= 3 + m = PyModule_Create(&_module); +#else + m = Py_InitModule3("_message", NULL, google::protobuf::python::module_docstring); +#endif + if (m == NULL) { + return INITFUNC_ERRORVAL; + } + + if (!google::protobuf::python::InitProto2MessageModule(m)) { + Py_DECREF(m); + return INITFUNC_ERRORVAL; + } + +#if PY_MAJOR_VERSION >= 3 + return m; +#endif + } +} +} // namespace google diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h new file mode 100644 index 00000000..28e504f0 --- /dev/null +++ b/python/google/protobuf/pyext/message.h @@ -0,0 +1,305 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__ + +#include + +#include +#ifndef _SHARED_PTR_H +#include +#endif +#include + + +namespace google { +namespace protobuf { + +class Message; +class Reflection; +class FieldDescriptor; + +using internal::shared_ptr; + +namespace python { + +struct CFieldDescriptor; +struct ExtensionDict; + +typedef struct CMessage { + PyObject_HEAD; + + // This is the top-level C++ Message object that owns the whole + // proto tree. Every Python CMessage holds a reference to it in + // order to keep it alive as long as there's a Python object that + // references any part of the tree. + shared_ptr owner; + + // Weak reference to a parent CMessage object. This is NULL for any top-level + // message and is set for any child message (i.e. a child submessage or a + // part of a repeated composite field). + // + // Used to make sure all ancestors are also mutable when first modifying + // a child submessage (in other words, turning a default message instance + // into a mutable one). + // + // If a submessage is released (becomes a new top-level message), this field + // MUST be set to NULL. The parent may get deallocated and further attempts + // to use this pointer will result in a crash. + struct CMessage* parent; + + // Weak reference to the parent's descriptor that describes this submessage. + // Used together with the parent's message when making a default message + // instance mutable. + // TODO(anuraag): With a bit of work on the Python/C++ layer, it should be + // possible to make this a direct pointer to a C++ FieldDescriptor, this would + // be easier if this implementation replaces upstream. + CFieldDescriptor* parent_field; + + // Pointer to the C++ Message object for this CMessage. The + // CMessage does not own this pointer. + Message* message; + + // Indicates this submessage is pointing to a default instance of a message. + // Submessages are always first created as read only messages and are then + // made writable, at which point this field is set to false. + bool read_only; + + // A reference to a Python dictionary containing CMessage, + // RepeatedCompositeContainer, and RepeatedScalarContainer + // objects. Used as a cache to make sure we don't have to make a + // Python wrapper for the C++ Message objects on every access, or + // deal with the synchronization nightmare that could create. + PyObject* composite_fields; + + // A reference to the dictionary containing the message's extensions. + // Similar to composite_fields, acting as a cache, but also contains the + // required extension dict logic. + ExtensionDict* extensions; +} CMessage; + +extern PyTypeObject CMessage_Type; + +namespace cmessage { + +// Create a new empty message that can be populated by the parent. +PyObject* NewEmpty(PyObject* type); + +// Release a submessage from its proto tree, making it a new top-level messgae. +// A new message will be created if this is a read-only default instance. +// +// Corresponds to reflection api method ReleaseMessage. +int ReleaseSubMessage(google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field_descriptor, + CMessage* child_cmessage); + +// Initializes a new CMessage instance for a submessage. Only called once per +// submessage as the result is cached in composite_fields. +// +// Corresponds to reflection api method GetMessage. +PyObject* InternalGetSubMessage(CMessage* self, + CFieldDescriptor* cfield_descriptor); + +// Deletes a range of C++ submessages in a repeated field (following a +// removal in a RepeatedCompositeContainer). +// +// Releases messages to the provided cmessage_list if it is not NULL rather +// than just removing them from the underlying proto. This cmessage_list must +// have a CMessage for each underlying submessage. The CMessages refered to +// by slice will be removed from cmessage_list by this function. +// +// Corresponds to reflection api method RemoveLast. +int InternalDeleteRepeatedField(google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field_descriptor, + PyObject* slice, PyObject* cmessage_list); + +// Sets the specified scalar value to the message. +int InternalSetScalar(CMessage* self, + const google::protobuf::FieldDescriptor* field_descriptor, + PyObject* value); + +// Retrieves the specified scalar value from the message. +// +// Returns a new python reference. +PyObject* InternalGetScalar(CMessage* self, + const google::protobuf::FieldDescriptor* field_descriptor); + +// Clears the message, removing all contained data. Extension dictionary and +// submessages are released first if there are remaining external references. +// +// Corresponds to message api method Clear. +PyObject* Clear(CMessage* self); + +// Clears the data described by the given descriptor. Used to clear extensions +// (which don't have names). Extension release is handled by ExtensionDict +// class, not this function. +// TODO(anuraag): Try to make this discrepancy in release semantics with +// ClearField less confusing. +// +// Corresponds to reflection api method ClearField. +PyObject* ClearFieldByDescriptor( + CMessage* self, + const google::protobuf::FieldDescriptor* descriptor); + +// Clears the data for the given field name. The message is released if there +// are any external references. +// +// Corresponds to reflection api method ClearField. +PyObject* ClearField(CMessage* self, PyObject* arg); + +// Checks if the message has the field described by the descriptor. Used for +// extensions (which have no name). +// +// Corresponds to reflection api method HasField +PyObject* HasFieldByDescriptor( + CMessage* self, const google::protobuf::FieldDescriptor* field_descriptor); + +// Checks if the message has the named field. +// +// Corresponds to reflection api method HasField. +PyObject* HasField(CMessage* self, PyObject* arg); + +// Initializes constants/enum values on a message. This is called by +// RepeatedCompositeContainer and ExtensionDict after calling the constructor. +// TODO(anuraag): Make it always called from within the constructor since it can +int InitAttributes(CMessage* self, PyObject* descriptor, PyObject* kwargs); + +PyObject* MergeFrom(CMessage* self, PyObject* arg); + +// Retrieves an attribute named 'name' from CMessage 'self'. Returns +// the attribute value on success, or NULL on failure. +// +// Returns a new reference. +PyObject* GetAttr(CMessage* self, PyObject* name); + +// Set the value of the attribute named 'name', for CMessage 'self', +// to the value 'value'. Returns -1 on failure. +int SetAttr(CMessage* self, PyObject* name, PyObject* value); + +PyObject* FindInitializationErrors(CMessage* self); + +// Set the owner field of self and any children of self, recursively. +// Used when self is being released and thus has a new owner (the +// released Message.) +int SetOwner(CMessage* self, const shared_ptr& new_owner); + +int AssureWritable(CMessage* self); + +} // namespace cmessage + +/* Is 64bit */ +#define IS_64BIT (SIZEOF_LONG == 8) + +#define FIELD_BELONGS_TO_MESSAGE(field_descriptor, message) \ + ((message)->GetDescriptor() == (field_descriptor)->containing_type()) + +#define FIELD_IS_REPEATED(field_descriptor) \ + ((field_descriptor)->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) + +#define GOOGLE_CHECK_GET_INT32(arg, value, err) \ + int32 value; \ + if (!CheckAndGetInteger(arg, &value, kint32min_py, kint32max_py)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_INT64(arg, value, err) \ + int64 value; \ + if (!CheckAndGetInteger(arg, &value, kint64min_py, kint64max_py)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_UINT32(arg, value, err) \ + uint32 value; \ + if (!CheckAndGetInteger(arg, &value, kPythonZero, kuint32max_py)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_UINT64(arg, value, err) \ + uint64 value; \ + if (!CheckAndGetInteger(arg, &value, kPythonZero, kuint64max_py)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_FLOAT(arg, value, err) \ + float value; \ + if (!CheckAndGetFloat(arg, &value)) { \ + return err; \ + } \ + +#define GOOGLE_CHECK_GET_DOUBLE(arg, value, err) \ + double value; \ + if (!CheckAndGetDouble(arg, &value)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_BOOL(arg, value, err) \ + bool value; \ + if (!CheckAndGetBool(arg, &value)) { \ + return err; \ + } + + +extern PyObject* kPythonZero; +extern PyObject* kint32min_py; +extern PyObject* kint32max_py; +extern PyObject* kuint32max_py; +extern PyObject* kint64min_py; +extern PyObject* kint64max_py; +extern PyObject* kuint64max_py; + +#define C(str) const_cast(str) + +void FormatTypeError(PyObject* arg, char* expected_types); +template +bool CheckAndGetInteger( + PyObject* arg, T* value, PyObject* min, PyObject* max); +bool CheckAndGetDouble(PyObject* arg, double* value); +bool CheckAndGetFloat(PyObject* arg, float* value); +bool CheckAndGetBool(PyObject* arg, bool* value); +bool CheckAndSetString( + PyObject* arg, google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* descriptor, + const google::protobuf::Reflection* reflection, + bool append, + int index); +PyObject* ToStringObject( + const google::protobuf::FieldDescriptor* descriptor, string value); + +extern PyObject* PickleError_class; + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__ diff --git a/python/google/protobuf/pyext/proto2_api_test.proto b/python/google/protobuf/pyext/proto2_api_test.proto new file mode 100644 index 00000000..eef9b730 --- /dev/null +++ b/python/google/protobuf/pyext/proto2_api_test.proto @@ -0,0 +1,38 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import "google/protobuf/internal/cpp/proto1_api_test.proto"; + +package google.protobuf.python.internal; + +message TestNestedProto1APIMessage { + optional int32 a = 1; + optional TestMessage.NestedMessage b = 2; +} diff --git a/python/google/protobuf/pyext/python.proto b/python/google/protobuf/pyext/python.proto new file mode 100644 index 00000000..ee6d5abe --- /dev/null +++ b/python/google/protobuf/pyext/python.proto @@ -0,0 +1,66 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: tibell@google.com (Johan Tibell) +// +// These message definitions are used to exercises known corner cases +// in the C++ implementation of the Python API. + + +package google.protobuf.python.internal; + +// Protos optimized for SPEED use a strict superset of the generated code +// of equivalent ones optimized for CODE_SIZE, so we should optimize all our +// tests for speed unless explicitly testing code size optimization. +option optimize_for = SPEED; + +message TestAllTypes { + message NestedMessage { + optional int32 bb = 1; + optional ForeignMessage cc = 2; + } + + repeated NestedMessage repeated_nested_message = 1; + optional NestedMessage optional_nested_message = 2; + optional int32 optional_int32 = 3; +} + +message ForeignMessage { + optional int32 c = 1; + repeated int32 d = 2; +} + +message TestAllExtensions { + extensions 1 to max; +} + +extend TestAllExtensions { + optional TestAllTypes.NestedMessage optional_nested_message_extension = 1; +} diff --git a/python/google/protobuf/pyext/python_protobuf.h b/python/google/protobuf/pyext/python_protobuf.h new file mode 100644 index 00000000..c5b0b1cd --- /dev/null +++ b/python/google/protobuf/pyext/python_protobuf.h @@ -0,0 +1,57 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: qrczak@google.com (Marcin Kowalczyk) +// +// This module exposes the C proto inside the given Python proto, in +// case the Python proto is implemented with a C proto. + +#ifndef GOOGLE_PROTOBUF_PYTHON_PYTHON_PROTOBUF_H__ +#define GOOGLE_PROTOBUF_PYTHON_PYTHON_PROTOBUF_H__ + +#include + +namespace google { +namespace protobuf { + +class Message; + +namespace python { + +// Return the pointer to the C proto inside the given Python proto, +// or NULL when this is not a Python proto implemented with a C proto. +const Message* GetCProtoInsidePyProto(PyObject* msg); +Message* MutableCProtoInsidePyProto(PyObject* msg); + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_PYTHON_PROTOBUF_H__ diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc new file mode 100644 index 00000000..b1645050 --- /dev/null +++ b/python/google/protobuf/pyext/repeated_composite_container.cc @@ -0,0 +1,763 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#include + +#include +#ifndef _SHARED_PTR_H +#include +#endif + +#include +#include +#include +#include +#include +#include +#include + +#if PY_MAJOR_VERSION >= 3 + #define PyInt_Check PyLong_Check + #define PyInt_AsLong PyLong_AsLong + #define PyInt_FromLong PyLong_FromLong +#endif + +namespace google { +namespace protobuf { +namespace python { + +extern google::protobuf::DynamicMessageFactory* global_message_factory; + +namespace repeated_composite_container { + +// TODO(tibell): We might also want to check: +// GOOGLE_CHECK_NOTNULL((self)->owner.get()); +#define GOOGLE_CHECK_ATTACHED(self) \ + do { \ + GOOGLE_CHECK_NOTNULL((self)->message); \ + GOOGLE_CHECK_NOTNULL((self)->parent_field); \ + } while (0); + +#define GOOGLE_CHECK_RELEASED(self) \ + do { \ + GOOGLE_CHECK((self)->owner.get() == NULL); \ + GOOGLE_CHECK((self)->message == NULL); \ + GOOGLE_CHECK((self)->parent_field == NULL); \ + GOOGLE_CHECK((self)->parent == NULL); \ + } while (0); + +// Returns a new reference. +static PyObject* GetKey(PyObject* x) { + // Just the identity function. + Py_INCREF(x); + return x; +} + +#define GET_KEY(keyfunc, value) \ + ((keyfunc) == NULL ? \ + GetKey((value)) : \ + PyObject_CallFunctionObjArgs((keyfunc), (value), NULL)) + +// Converts a comparison function that returns -1, 0, or 1 into a +// less-than predicate. +// +// Returns -1 on error, 1 if x < y, 0 if x >= y. +static int islt(PyObject *x, PyObject *y, PyObject *compare) { + if (compare == NULL) + return PyObject_RichCompareBool(x, y, Py_LT); + + ScopedPyObjectPtr res(PyObject_CallFunctionObjArgs(compare, x, y, NULL)); + if (res == NULL) + return -1; + if (!PyInt_Check(res)) { + PyErr_Format(PyExc_TypeError, + "comparison function must return int, not %.200s", + Py_TYPE(res)->tp_name); + return -1; + } + return PyInt_AsLong(res) < 0; +} + +// Copied from uarrsort.c but swaps memcpy swaps with protobuf/python swaps +// TODO(anuraag): Is there a better way to do this then reinventing the wheel? +static int InternalQuickSort(RepeatedCompositeContainer* self, + Py_ssize_t start, + Py_ssize_t limit, + PyObject* cmp, + PyObject* keyfunc) { + if (limit - start <= 1) + return 0; // Nothing to sort. + + GOOGLE_CHECK_ATTACHED(self); + + google::protobuf::Message* message = self->message; + const google::protobuf::Reflection* reflection = message->GetReflection(); + const google::protobuf::FieldDescriptor* descriptor = self->parent_field->descriptor; + Py_ssize_t left; + Py_ssize_t right; + + PyObject* children = self->child_messages; + + do { + left = start; + right = limit; + ScopedPyObjectPtr mid( + GET_KEY(keyfunc, PyList_GET_ITEM(children, (start + limit) / 2))); + do { + ScopedPyObjectPtr key(GET_KEY(keyfunc, PyList_GET_ITEM(children, left))); + int is_lt = islt(key, mid, cmp); + if (is_lt == -1) + return -1; + /* array[left]SwapElements(message, descriptor, left, right); + PyObject* tmp = PyList_GET_ITEM(children, left); + PyList_SET_ITEM(children, left, PyList_GET_ITEM(children, right)); + PyList_SET_ITEM(children, right, tmp); + } + ++left; + } + } while (left < right); + + if ((right - start) < (limit - left)) { + /* sort [start..right[ */ + if (start < (right - 1)) { + InternalQuickSort(self, start, right, cmp, keyfunc); + } + + /* sort [left..limit[ */ + start = left; + } else { + /* sort [left..limit[ */ + if (left < (limit - 1)) { + InternalQuickSort(self, left, limit, cmp, keyfunc); + } + + /* sort [start..right[ */ + limit = right; + } + } while (start < (limit - 1)); + + return 0; +} + +#undef GET_KEY + +// --------------------------------------------------------------------- +// len() + +static Py_ssize_t Length(RepeatedCompositeContainer* self) { + google::protobuf::Message* message = self->message; + if (message != NULL) { + return message->GetReflection()->FieldSize(*message, + self->parent_field->descriptor); + } else { + // The container has been released (i.e. by a call to Clear() or + // ClearField() on the parent) and thus there's no message. + return PyList_GET_SIZE(self->child_messages); + } +} + +// Returns 0 if successful; returns -1 and sets an exception if +// unsuccessful. +static int UpdateChildMessages(RepeatedCompositeContainer* self) { + if (self->message == NULL) + return 0; + + // A MergeFrom on a parent message could have caused extra messages to be + // added in the underlying protobuf so add them to our list. They can never + // be removed in such a way so there's no need to worry about that. + Py_ssize_t message_length = Length(self); + Py_ssize_t child_length = PyList_GET_SIZE(self->child_messages); + google::protobuf::Message* message = self->message; + const google::protobuf::Reflection* reflection = message->GetReflection(); + for (Py_ssize_t i = child_length; i < message_length; ++i) { + const Message& sub_message = reflection->GetRepeatedMessage( + *(self->message), self->parent_field->descriptor, i); + ScopedPyObjectPtr py_cmsg(cmessage::NewEmpty(self->subclass_init)); + if (py_cmsg == NULL) { + return -1; + } + CMessage* cmsg = reinterpret_cast(py_cmsg.get()); + cmsg->owner = self->owner; + cmsg->message = const_cast(&sub_message); + cmsg->parent = self->parent; + if (cmessage::InitAttributes(cmsg, NULL, NULL) < 0) { + return -1; + } + PyList_Append(self->child_messages, py_cmsg); + } + return 0; +} + +// --------------------------------------------------------------------- +// add() + +static PyObject* AddToAttached(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwargs) { + GOOGLE_CHECK_ATTACHED(self); + + if (UpdateChildMessages(self) < 0) { + return NULL; + } + if (cmessage::AssureWritable(self->parent) == -1) + return NULL; + google::protobuf::Message* message = self->message; + google::protobuf::Message* sub_message = + message->GetReflection()->AddMessage(message, + self->parent_field->descriptor); + PyObject* py_cmsg = cmessage::NewEmpty(self->subclass_init); + if (py_cmsg == NULL) { + return NULL; + } + CMessage* cmsg = reinterpret_cast(py_cmsg); + + cmsg->owner = self->owner; + cmsg->message = sub_message; + cmsg->parent = self->parent; + // cmessage::InitAttributes must be called after cmsg->message has + // been set. + if (cmessage::InitAttributes(cmsg, NULL, kwargs) < 0) { + Py_DECREF(py_cmsg); + return NULL; + } + PyList_Append(self->child_messages, py_cmsg); + return py_cmsg; +} + +static PyObject* AddToReleased(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwargs) { + GOOGLE_CHECK_RELEASED(self); + + // Create the CMessage + PyObject* py_cmsg = PyObject_CallObject(self->subclass_init, NULL); + if (py_cmsg == NULL) + return NULL; + CMessage* cmsg = reinterpret_cast(py_cmsg); + if (cmessage::InitAttributes(cmsg, NULL, kwargs) < 0) { + Py_DECREF(py_cmsg); + return NULL; + } + + // The Message got created by the call to subclass_init above and + // it set self->owner to the newly allocated message. + + PyList_Append(self->child_messages, py_cmsg); + return py_cmsg; +} + +PyObject* Add(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwargs) { + if (self->message == NULL) + return AddToReleased(self, args, kwargs); + else + return AddToAttached(self, args, kwargs); +} + +// --------------------------------------------------------------------- +// extend() + +PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value) { + cmessage::AssureWritable(self->parent); + if (UpdateChildMessages(self) < 0) { + return NULL; + } + ScopedPyObjectPtr iter(PyObject_GetIter(value)); + if (iter == NULL) { + PyErr_SetString(PyExc_TypeError, "Value must be iterable"); + return NULL; + } + ScopedPyObjectPtr next; + while ((next.reset(PyIter_Next(iter))) != NULL) { + if (!PyObject_TypeCheck(next, &CMessage_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a cmessage"); + return NULL; + } + ScopedPyObjectPtr new_message(Add(self, NULL, NULL)); + if (new_message == NULL) { + return NULL; + } + CMessage* new_cmessage = reinterpret_cast(new_message.get()); + if (cmessage::MergeFrom(new_cmessage, next) == NULL) { + return NULL; + } + } + if (PyErr_Occurred()) { + return NULL; + } + Py_RETURN_NONE; +} + +PyObject* MergeFrom(RepeatedCompositeContainer* self, PyObject* other) { + if (UpdateChildMessages(self) < 0) { + return NULL; + } + return Extend(self, other); +} + +PyObject* Subscript(RepeatedCompositeContainer* self, PyObject* slice) { + if (UpdateChildMessages(self) < 0) { + return NULL; + } + Py_ssize_t from; + Py_ssize_t to; + Py_ssize_t step; + Py_ssize_t length = Length(self); + Py_ssize_t slicelength; + if (PySlice_Check(slice)) { +#if PY_MAJOR_VERSION >= 3 + if (PySlice_GetIndicesEx(slice, +#else + if (PySlice_GetIndicesEx(reinterpret_cast(slice), +#endif + length, &from, &to, &step, &slicelength) == -1) { + return NULL; + } + return PyList_GetSlice(self->child_messages, from, to); + } else if (PyInt_Check(slice) || PyLong_Check(slice)) { + from = to = PyLong_AsLong(slice); + if (from < 0) { + from = to = length + from; + } + PyObject* result = PyList_GetItem(self->child_messages, from); + if (result == NULL) { + return NULL; + } + Py_INCREF(result); + return result; + } + PyErr_SetString(PyExc_TypeError, "index must be an integer or slice"); + return NULL; +} + +int AssignSubscript(RepeatedCompositeContainer* self, + PyObject* slice, + PyObject* value) { + if (UpdateChildMessages(self) < 0) { + return -1; + } + if (value != NULL) { + PyErr_SetString(PyExc_TypeError, "does not support assignment"); + return -1; + } + + // Delete from the underlying Message, if any. + if (self->message != NULL) { + if (cmessage::InternalDeleteRepeatedField(self->message, + self->parent_field->descriptor, + slice, + self->child_messages) < 0) { + return -1; + } + } else { + Py_ssize_t from; + Py_ssize_t to; + Py_ssize_t step; + Py_ssize_t length = Length(self); + Py_ssize_t slicelength; + if (PySlice_Check(slice)) { +#if PY_MAJOR_VERSION >= 3 + if (PySlice_GetIndicesEx(slice, +#else + if (PySlice_GetIndicesEx(reinterpret_cast(slice), +#endif + length, &from, &to, &step, &slicelength) == -1) { + return -1; + } + return PySequence_DelSlice(self->child_messages, from, to); + } else if (PyInt_Check(slice) || PyLong_Check(slice)) { + from = to = PyLong_AsLong(slice); + if (from < 0) { + from = to = length + from; + } + return PySequence_DelItem(self->child_messages, from); + } + } + + return 0; +} + +static PyObject* Remove(RepeatedCompositeContainer* self, PyObject* value) { + if (UpdateChildMessages(self) < 0) { + return NULL; + } + Py_ssize_t index = PySequence_Index(self->child_messages, value); + if (index == -1) { + return NULL; + } + ScopedPyObjectPtr py_index(PyLong_FromLong(index)); + if (AssignSubscript(self, py_index, NULL) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +static PyObject* RichCompare(RepeatedCompositeContainer* self, + PyObject* other, + int opid) { + if (UpdateChildMessages(self) < 0) { + return NULL; + } + if (!PyObject_TypeCheck(other, &RepeatedCompositeContainer_Type)) { + PyErr_SetString(PyExc_TypeError, + "Can only compare repeated composite fields " + "against other repeated composite fields."); + return NULL; + } + if (opid == Py_EQ || opid == Py_NE) { + // TODO(anuraag): Don't make new lists just for this... + ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); + if (full_slice == NULL) { + return NULL; + } + ScopedPyObjectPtr list(Subscript(self, full_slice)); + if (list == NULL) { + return NULL; + } + ScopedPyObjectPtr other_list( + Subscript( + reinterpret_cast(other), full_slice)); + if (other_list == NULL) { + return NULL; + } + return PyObject_RichCompare(list, other_list, opid); + } else { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } +} + +// --------------------------------------------------------------------- +// sort() + +static PyObject* SortAttached(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwds) { + // Sort the underlying Message array. + PyObject *compare = NULL; + int reverse = 0; + PyObject *keyfunc = NULL; + static char *kwlist[] = {"cmp", "key", "reverse", 0}; + + if (args != NULL) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOi:sort", + kwlist, &compare, &keyfunc, &reverse)) + return NULL; + } + if (compare == Py_None) + compare = NULL; + if (keyfunc == Py_None) + keyfunc = NULL; + + const Py_ssize_t length = Length(self); + if (InternalQuickSort(self, 0, length, compare, keyfunc) < 0) + return NULL; + + // Finally reverse the result if requested. + if (reverse) { + google::protobuf::Message* message = self->message; + const google::protobuf::Reflection* reflection = message->GetReflection(); + const google::protobuf::FieldDescriptor* descriptor = self->parent_field->descriptor; + + // Reverse the Message array. + for (int i = 0; i < length / 2; ++i) + reflection->SwapElements(message, descriptor, i, length - i - 1); + + // Reverse the Python list. + ScopedPyObjectPtr res(PyObject_CallMethod(self->child_messages, + "reverse", NULL)); + if (res == NULL) + return NULL; + } + + Py_RETURN_NONE; +} + +static PyObject* SortReleased(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwds) { + ScopedPyObjectPtr m(PyObject_GetAttrString(self->child_messages, "sort")); + if (m == NULL) + return NULL; + if (PyObject_Call(m, args, kwds) == NULL) + return NULL; + Py_RETURN_NONE; +} + +static PyObject* Sort(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwds) { + // Support the old sort_function argument for backwards + // compatibility. + if (kwds != NULL) { + PyObject* sort_func = PyDict_GetItemString(kwds, "sort_function"); + if (sort_func != NULL) { + // Must set before deleting as sort_func is a borrowed reference + // and kwds might be the only thing keeping it alive. + PyDict_SetItemString(kwds, "cmp", sort_func); + PyDict_DelItemString(kwds, "sort_function"); + } + } + + if (UpdateChildMessages(self) < 0) + return NULL; + if (self->message == NULL) { + return SortReleased(self, args, kwds); + } else { + return SortAttached(self, args, kwds); + } +} + +// --------------------------------------------------------------------- + +static PyObject* Item(RepeatedCompositeContainer* self, Py_ssize_t index) { + if (UpdateChildMessages(self) < 0) { + return NULL; + } + Py_ssize_t length = Length(self); + if (index < 0) { + index = length + index; + } + PyObject* item = PyList_GetItem(self->child_messages, index); + if (item == NULL) { + return NULL; + } + Py_INCREF(item); + return item; +} + +// The caller takes ownership of the returned Message. +Message* ReleaseLast(const FieldDescriptor* field, + const Descriptor* type, + Message* message) { + GOOGLE_CHECK_NOTNULL(field); + GOOGLE_CHECK_NOTNULL(type); + GOOGLE_CHECK_NOTNULL(message); + + Message* released_message = message->GetReflection()->ReleaseLast( + message, field); + // TODO(tibell): Deal with proto1. + + // ReleaseMessage will return NULL which differs from + // child_cmessage->message, if the field does not exist. In this case, + // the latter points to the default instance via a const_cast<>, so we + // have to reset it to a new mutable object since we are taking ownership. + if (released_message == NULL) { + const Message* prototype = global_message_factory->GetPrototype(type); + GOOGLE_CHECK_NOTNULL(prototype); + return prototype->New(); + } else { + return released_message; + } +} + +// Release field of message and transfer the ownership to cmessage. +void ReleaseLastTo(const FieldDescriptor* field, + Message* message, + CMessage* cmessage) { + GOOGLE_CHECK_NOTNULL(field); + GOOGLE_CHECK_NOTNULL(message); + GOOGLE_CHECK_NOTNULL(cmessage); + + shared_ptr released_message( + ReleaseLast(field, cmessage->message->GetDescriptor(), message)); + cmessage->parent = NULL; + cmessage->parent_field = NULL; + cmessage->message = released_message.get(); + cmessage->read_only = false; + cmessage::SetOwner(cmessage, released_message); +} + +// Called to release a container using +// ClearField('container_field_name') on the parent. +int Release(RepeatedCompositeContainer* self) { + if (UpdateChildMessages(self) < 0) { + PyErr_WriteUnraisable(PyBytes_FromString("Failed to update released " + "messages")); + return -1; + } + + Message* message = self->message; + const FieldDescriptor* field = self->parent_field->descriptor; + + // The reflection API only lets us release the last message in a + // repeated field. Therefore we iterate through the children + // starting with the last one. + const Py_ssize_t size = PyList_GET_SIZE(self->child_messages); + GOOGLE_DCHECK_EQ(size, message->GetReflection()->FieldSize(*message, field)); + for (Py_ssize_t i = size - 1; i >= 0; --i) { + CMessage* child_cmessage = reinterpret_cast( + PyList_GET_ITEM(self->child_messages, i)); + ReleaseLastTo(field, message, child_cmessage); + } + + // Detach from containing message. + self->parent = NULL; + self->parent_field = NULL; + self->message = NULL; + self->owner.reset(); + + return 0; +} + +int SetOwner(RepeatedCompositeContainer* self, + const shared_ptr& new_owner) { + GOOGLE_CHECK_ATTACHED(self); + + self->owner = new_owner; + const Py_ssize_t n = PyList_GET_SIZE(self->child_messages); + for (Py_ssize_t i = 0; i < n; ++i) { + PyObject* msg = PyList_GET_ITEM(self->child_messages, i); + if (cmessage::SetOwner(reinterpret_cast(msg), new_owner) == -1) { + return -1; + } + } + return 0; +} + +static int Init(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwargs) { + self->message = NULL; + self->parent = NULL; + self->parent_field = NULL; + self->subclass_init = NULL; + self->child_messages = PyList_New(0); + return 0; +} + +static void Dealloc(RepeatedCompositeContainer* self) { + Py_CLEAR(self->child_messages); + // TODO(tibell): Do we need to call delete on these objects to make + // sure their destructors are called? + self->owner.reset(); + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + +static PySequenceMethods SqMethods = { + (lenfunc)Length, /* sq_length */ + 0, /* sq_concat */ + 0, /* sq_repeat */ + (ssizeargfunc)Item /* sq_item */ +}; + +static PyMappingMethods MpMethods = { + (lenfunc)Length, /* mp_length */ + (binaryfunc)Subscript, /* mp_subscript */ + (objobjargproc)AssignSubscript,/* mp_ass_subscript */ +}; + +static PyMethodDef Methods[] = { + { "add", (PyCFunction) Add, METH_VARARGS | METH_KEYWORDS, + "Adds an object to the repeated container." }, + { "extend", (PyCFunction) Extend, METH_O, + "Adds objects to the repeated container." }, + { "remove", (PyCFunction) Remove, METH_O, + "Removes an object from the repeated container." }, + { "sort", (PyCFunction) Sort, METH_VARARGS | METH_KEYWORDS, + "Sorts the repeated container." }, + { "MergeFrom", (PyCFunction) MergeFrom, METH_O, + "Adds objects to the repeated container." }, + { NULL, NULL } +}; + +} // namespace repeated_composite_container + +PyTypeObject RepeatedCompositeContainer_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "google.protobuf.internal." + "cpp._message.RepeatedCompositeContainer", // tp_name + sizeof(RepeatedCompositeContainer), // tp_basicsize + 0, // tp_itemsize + (destructor)repeated_composite_container::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + &repeated_composite_container::SqMethods, // tp_as_sequence + &repeated_composite_container::MpMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Repeated scalar container", // tp_doc + 0, // tp_traverse + 0, // tp_clear + (richcmpfunc)repeated_composite_container::RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + repeated_composite_container::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + (initproc)repeated_composite_container::Init, // tp_init +}; + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/repeated_composite_container.h b/python/google/protobuf/pyext/repeated_composite_container.h new file mode 100644 index 00000000..e8ed30ed --- /dev/null +++ b/python/google/protobuf/pyext/repeated_composite_container.h @@ -0,0 +1,172 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__ + +#include + +#include +#ifndef _SHARED_PTR_H +#include +#endif +#include +#include + + +namespace google { +namespace protobuf { + +class FieldDescriptor; +class Message; + +using internal::shared_ptr; + +namespace python { + +struct CMessage; +struct CFieldDescriptor; + +// A RepeatedCompositeContainer can be in one of two states: attached +// or released. +// +// When in the attached state all modifications to the container are +// done both on the 'message' and on the 'child_messages' +// list. In this state all Messages refered to by the children in +// 'child_messages' are owner by the 'owner'. +// +// When in the released state 'message', 'owner', 'parent', and +// 'parent_field' are NULL. +typedef struct RepeatedCompositeContainer { + PyObject_HEAD; + + // This is the top-level C++ Message object that owns the whole + // proto tree. Every Python RepeatedCompositeContainer holds a + // reference to it in order to keep it alive as long as there's a + // Python object that references any part of the tree. + shared_ptr owner; + + // Weak reference to parent object. May be NULL. Used to make sure + // the parent is writable before modifying the + // RepeatedCompositeContainer. + CMessage* parent; + + // A descriptor used to modify the underlying 'message'. + CFieldDescriptor* parent_field; + + // Pointer to the C++ Message that contains this container. The + // RepeatedCompositeContainer does not own this pointer. + // + // If NULL, this message has been released from its parent (by + // calling Clear() or ClearField() on the parent. + Message* message; + + // A callable that is used to create new child messages. + PyObject* subclass_init; + + // A list of child messages. + PyObject* child_messages; +} RepeatedCompositeContainer; + +extern PyTypeObject RepeatedCompositeContainer_Type; + +namespace repeated_composite_container { + +// Returns the number of items in this repeated composite container. +static Py_ssize_t Length(RepeatedCompositeContainer* self); + +// Appends a new CMessage to the container and returns it. The +// CMessage is initialized using the content of kwargs. +// +// Returns a new reference if successful; returns NULL and sets an +// exception if unsuccessful. +PyObject* Add(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwargs); + +// Appends all the CMessages in the input iterator to the container. +// +// Returns None if successful; returns NULL and sets an exception if +// unsuccessful. +PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value); + +// Appends a new message to the container for each message in the +// input iterator, merging each data element in. Equivalent to extend. +// +// Returns None if successful; returns NULL and sets an exception if +// unsuccessful. +PyObject* MergeFrom(RepeatedCompositeContainer* self, PyObject* other); + +// Accesses messages in the container. +// +// Returns a new reference to the message for an integer parameter. +// Returns a new reference to a list of messages for a slice. +PyObject* Subscript(RepeatedCompositeContainer* self, PyObject* slice); + +// Deletes items from the container (cannot be used for assignment). +// +// Returns 0 on success, -1 on failure. +int AssignSubscript(RepeatedCompositeContainer* self, + PyObject* slice, + PyObject* value); + +// Releases the messages in the container to the given message. +// +// Returns 0 on success, -1 on failure. +int ReleaseToMessage(RepeatedCompositeContainer* self, + google::protobuf::Message* new_message); + +// Releases the messages in the container to a new message. +// +// Returns 0 on success, -1 on failure. +int Release(RepeatedCompositeContainer* self); + +// Returns 0 on success, -1 on failure. +int SetOwner(RepeatedCompositeContainer* self, + const shared_ptr& new_owner); + +// Removes the last element of the repeated message field 'field' on +// the Message 'message', and transfers the ownership of the released +// Message to 'cmessage'. +// +// Corresponds to reflection api method ReleaseMessage. +void ReleaseLastTo(const FieldDescriptor* field, + Message* message, + CMessage* cmessage); + +} // namespace repeated_composite_container +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__ diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc new file mode 100644 index 00000000..b0fcd816 --- /dev/null +++ b/python/google/protobuf/pyext/repeated_scalar_container.cc @@ -0,0 +1,825 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#include + +#include +#ifndef _SHARED_PTR_H +#include +#endif + +#include +#include +#include +#include +#include +#include +#include + +#if PY_MAJOR_VERSION >= 3 + #define PyInt_FromLong PyLong_FromLong + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #else + #define PyString_AsString(ob) \ + (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob)) + #endif +#endif + +namespace google { +namespace protobuf { +namespace python { + +extern google::protobuf::DynamicMessageFactory* global_message_factory; + +namespace repeated_scalar_container { + +static int InternalAssignRepeatedField( + RepeatedScalarContainer* self, PyObject* list) { + self->message->GetReflection()->ClearField(self->message, + self->parent_field->descriptor); + for (Py_ssize_t i = 0; i < PyList_GET_SIZE(list); ++i) { + PyObject* value = PyList_GET_ITEM(list, i); + if (Append(self, value) == NULL) { + return -1; + } + } + return 0; +} + +static Py_ssize_t Len(RepeatedScalarContainer* self) { + google::protobuf::Message* message = self->message; + return message->GetReflection()->FieldSize(*message, + self->parent_field->descriptor); +} + +static int AssignItem(RepeatedScalarContainer* self, + Py_ssize_t index, + PyObject* arg) { + cmessage::AssureWritable(self->parent); + google::protobuf::Message* message = self->message; + const google::protobuf::FieldDescriptor* field_descriptor = + self->parent_field->descriptor; + if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { + PyErr_SetString( + PyExc_KeyError, "Field does not belong to message!"); + return -1; + } + + const google::protobuf::Reflection* reflection = message->GetReflection(); + int field_size = reflection->FieldSize(*message, field_descriptor); + if (index < 0) { + index = field_size + index; + } + if (index < 0 || index >= field_size) { + PyErr_Format(PyExc_IndexError, + "list assignment index (%d) out of range", + static_cast(index)); + return -1; + } + + if (arg == NULL) { + ScopedPyObjectPtr py_index(PyLong_FromLong(index)); + return cmessage::InternalDeleteRepeatedField(message, field_descriptor, + py_index, NULL); + } + + if (PySequence_Check(arg) && !(PyBytes_Check(arg) || PyUnicode_Check(arg))) { + PyErr_SetString(PyExc_TypeError, "Value must be scalar"); + return -1; + } + + switch (field_descriptor->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(arg, value, -1); + reflection->SetRepeatedInt32(message, field_descriptor, index, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(arg, value, -1); + reflection->SetRepeatedInt64(message, field_descriptor, index, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(arg, value, -1); + reflection->SetRepeatedUInt32(message, field_descriptor, index, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(arg, value, -1); + reflection->SetRepeatedUInt64(message, field_descriptor, index, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + GOOGLE_CHECK_GET_FLOAT(arg, value, -1); + reflection->SetRepeatedFloat(message, field_descriptor, index, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + GOOGLE_CHECK_GET_DOUBLE(arg, value, -1); + reflection->SetRepeatedDouble(message, field_descriptor, index, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(arg, value, -1); + reflection->SetRepeatedBool(message, field_descriptor, index, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + if (!CheckAndSetString( + arg, message, field_descriptor, reflection, false, index)) { + return -1; + } + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + GOOGLE_CHECK_GET_INT32(arg, value, -1); + const google::protobuf::EnumDescriptor* enum_descriptor = + field_descriptor->enum_type(); + const google::protobuf::EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + reflection->SetRepeatedEnum(message, field_descriptor, index, + enum_value); + } else { + ScopedPyObjectPtr s(PyObject_Str(arg)); + if (s != NULL) { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %s", + PyString_AsString(s.get())); + } + return -1; + } + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Adding value to a field of unknown type %d", + field_descriptor->cpp_type()); + return -1; + } + return 0; +} + +static PyObject* Item(RepeatedScalarContainer* self, Py_ssize_t index) { + google::protobuf::Message* message = self->message; + const google::protobuf::FieldDescriptor* field_descriptor = + self->parent_field->descriptor; + const google::protobuf::Reflection* reflection = message->GetReflection(); + + int field_size = reflection->FieldSize(*message, field_descriptor); + if (index < 0) { + index = field_size + index; + } + if (index < 0 || index >= field_size) { + PyErr_Format(PyExc_IndexError, + "list assignment index (%d) out of range", + static_cast(index)); + return NULL; + } + + PyObject* result = NULL; + switch (field_descriptor->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + int32 value = reflection->GetRepeatedInt32( + *message, field_descriptor, index); + result = PyInt_FromLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + int64 value = reflection->GetRepeatedInt64( + *message, field_descriptor, index); + result = PyLong_FromLongLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + uint32 value = reflection->GetRepeatedUInt32( + *message, field_descriptor, index); + result = PyLong_FromLongLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + uint64 value = reflection->GetRepeatedUInt64( + *message, field_descriptor, index); + result = PyLong_FromUnsignedLongLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + float value = reflection->GetRepeatedFloat( + *message, field_descriptor, index); + result = PyFloat_FromDouble(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + double value = reflection->GetRepeatedDouble( + *message, field_descriptor, index); + result = PyFloat_FromDouble(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + bool value = reflection->GetRepeatedBool( + *message, field_descriptor, index); + result = PyBool_FromLong(value ? 1 : 0); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + const google::protobuf::EnumValueDescriptor* enum_value = + message->GetReflection()->GetRepeatedEnum( + *message, field_descriptor, index); + result = PyInt_FromLong(enum_value->number()); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + string value = reflection->GetRepeatedString( + *message, field_descriptor, index); + result = ToStringObject(field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { + PyObject* py_cmsg = PyObject_CallObject(reinterpret_cast( + &CMessage_Type), NULL); + if (py_cmsg == NULL) { + return NULL; + } + CMessage* cmsg = reinterpret_cast(py_cmsg); + const google::protobuf::Message& msg = reflection->GetRepeatedMessage( + *message, field_descriptor, index); + cmsg->owner = self->owner; + cmsg->parent = self->parent; + cmsg->message = const_cast(&msg); + cmsg->read_only = false; + result = reinterpret_cast(py_cmsg); + break; + } + default: + PyErr_Format( + PyExc_SystemError, + "Getting value from a repeated field of unknown type %d", + field_descriptor->cpp_type()); + } + + return result; +} + +static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) { + Py_ssize_t from; + Py_ssize_t to; + Py_ssize_t step; + Py_ssize_t length; + Py_ssize_t slicelength; + bool return_list = false; +#if PY_MAJOR_VERSION < 3 + if (PyInt_Check(slice)) { + from = to = PyInt_AsLong(slice); + } else // NOLINT +#endif + if (PyLong_Check(slice)) { + from = to = PyLong_AsLong(slice); + } else if (PySlice_Check(slice)) { + length = Len(self); +#if PY_MAJOR_VERSION >= 3 + if (PySlice_GetIndicesEx(slice, +#else + if (PySlice_GetIndicesEx(reinterpret_cast(slice), +#endif + length, &from, &to, &step, &slicelength) == -1) { + return NULL; + } + return_list = true; + } else { + PyErr_SetString(PyExc_TypeError, "list indices must be integers"); + return NULL; + } + + if (!return_list) { + return Item(self, from); + } + + PyObject* list = PyList_New(0); + if (list == NULL) { + return NULL; + } + if (from <= to) { + if (step < 0) { + return list; + } + for (Py_ssize_t index = from; index < to; index += step) { + if (index < 0 || index >= length) { + break; + } + ScopedPyObjectPtr s(Item(self, index)); + PyList_Append(list, s); + } + } else { + if (step > 0) { + return list; + } + for (Py_ssize_t index = from; index > to; index += step) { + if (index < 0 || index >= length) { + break; + } + ScopedPyObjectPtr s(Item(self, index)); + PyList_Append(list, s); + } + } + return list; +} + +PyObject* Append(RepeatedScalarContainer* self, PyObject* item) { + cmessage::AssureWritable(self->parent); + google::protobuf::Message* message = self->message; + const google::protobuf::FieldDescriptor* field_descriptor = + self->parent_field->descriptor; + + if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { + PyErr_SetString( + PyExc_KeyError, "Field does not belong to message!"); + return NULL; + } + + const google::protobuf::Reflection* reflection = message->GetReflection(); + switch (field_descriptor->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(item, value, NULL); + reflection->AddInt32(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(item, value, NULL); + reflection->AddInt64(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(item, value, NULL); + reflection->AddUInt32(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(item, value, NULL); + reflection->AddUInt64(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + GOOGLE_CHECK_GET_FLOAT(item, value, NULL); + reflection->AddFloat(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + GOOGLE_CHECK_GET_DOUBLE(item, value, NULL); + reflection->AddDouble(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(item, value, NULL); + reflection->AddBool(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + if (!CheckAndSetString( + item, message, field_descriptor, reflection, true, -1)) { + return NULL; + } + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + GOOGLE_CHECK_GET_INT32(item, value, NULL); + const google::protobuf::EnumDescriptor* enum_descriptor = + field_descriptor->enum_type(); + const google::protobuf::EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + reflection->AddEnum(message, field_descriptor, enum_value); + } else { + ScopedPyObjectPtr s(PyObject_Str(item)); + if (s != NULL) { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %s", + PyString_AsString(s.get())); + } + return NULL; + } + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Adding value to a field of unknown type %d", + field_descriptor->cpp_type()); + return NULL; + } + + Py_RETURN_NONE; +} + +static int AssSubscript(RepeatedScalarContainer* self, + PyObject* slice, + PyObject* value) { + Py_ssize_t from; + Py_ssize_t to; + Py_ssize_t step; + Py_ssize_t length; + Py_ssize_t slicelength; + bool create_list = false; + + cmessage::AssureWritable(self->parent); + google::protobuf::Message* message = self->message; + const google::protobuf::FieldDescriptor* field_descriptor = + self->parent_field->descriptor; + +#if PY_MAJOR_VERSION < 3 + if (PyInt_Check(slice)) { + from = to = PyInt_AsLong(slice); + } else +#endif + if (PyLong_Check(slice)) { + from = to = PyLong_AsLong(slice); + } else if (PySlice_Check(slice)) { + const google::protobuf::Reflection* reflection = message->GetReflection(); + length = reflection->FieldSize(*message, field_descriptor); +#if PY_MAJOR_VERSION >= 3 + if (PySlice_GetIndicesEx(slice, +#else + if (PySlice_GetIndicesEx(reinterpret_cast(slice), +#endif + length, &from, &to, &step, &slicelength) == -1) { + return -1; + } + create_list = true; + } else { + PyErr_SetString(PyExc_TypeError, "list indices must be integers"); + return -1; + } + + if (value == NULL) { + return cmessage::InternalDeleteRepeatedField( + message, field_descriptor, slice, NULL); + } + + if (!create_list) { + return AssignItem(self, from, value); + } + + ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); + if (full_slice == NULL) { + return -1; + } + ScopedPyObjectPtr new_list(Subscript(self, full_slice)); + if (new_list == NULL) { + return -1; + } + if (PySequence_SetSlice(new_list, from, to, value) < 0) { + return -1; + } + + return InternalAssignRepeatedField(self, new_list); +} + +PyObject* Extend(RepeatedScalarContainer* self, PyObject* value) { + cmessage::AssureWritable(self->parent); + if (PyObject_Not(value)) { + Py_RETURN_NONE; + } + ScopedPyObjectPtr iter(PyObject_GetIter(value)); + if (iter == NULL) { + PyErr_SetString(PyExc_TypeError, "Value must be iterable"); + return NULL; + } + ScopedPyObjectPtr next; + while ((next.reset(PyIter_Next(iter))) != NULL) { + if (Append(self, next) == NULL) { + return NULL; + } + } + if (PyErr_Occurred()) { + return NULL; + } + Py_RETURN_NONE; +} + +static PyObject* Insert(RepeatedScalarContainer* self, PyObject* args) { + Py_ssize_t index; + PyObject* value; + if (!PyArg_ParseTuple(args, "lO", &index, &value)) { + return NULL; + } + ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); + ScopedPyObjectPtr new_list(Subscript(self, full_slice)); + if (PyList_Insert(new_list, index, value) < 0) { + return NULL; + } + int ret = InternalAssignRepeatedField(self, new_list); + if (ret < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +static PyObject* Remove(RepeatedScalarContainer* self, PyObject* value) { + Py_ssize_t match_index = -1; + for (Py_ssize_t i = 0; i < Len(self); ++i) { + ScopedPyObjectPtr elem(Item(self, i)); + if (PyObject_RichCompareBool(elem, value, Py_EQ)) { + match_index = i; + break; + } + } + if (match_index == -1) { + PyErr_SetString(PyExc_ValueError, "remove(x): x not in container"); + return NULL; + } + if (AssignItem(self, match_index, NULL) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +static PyObject* RichCompare(RepeatedScalarContainer* self, + PyObject* other, + int opid) { + if (opid != Py_EQ && opid != Py_NE) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + + // Copy the contents of this repeated scalar container, and other if it is + // also a repeated scalar container, into Python lists so we can delegate + // to the list's compare method. + + ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); + if (full_slice == NULL) { + return NULL; + } + + ScopedPyObjectPtr other_list_deleter; + if (PyObject_TypeCheck(other, &RepeatedScalarContainer_Type)) { + other_list_deleter.reset(Subscript( + reinterpret_cast(other), full_slice)); + other = other_list_deleter.get(); + } + + ScopedPyObjectPtr list(Subscript(self, full_slice)); + if (list == NULL) { + return NULL; + } + return PyObject_RichCompare(list, other, opid); +} + +PyObject* Reduce(RepeatedScalarContainer* unused_self) { + PyErr_Format( + PickleError_class, + "can't pickle repeated message fields, convert to list first"); + return NULL; +} + +static PyObject* Sort(RepeatedScalarContainer* self, + PyObject* args, + PyObject* kwds) { + // Support the old sort_function argument for backwards + // compatibility. + if (kwds != NULL) { + PyObject* sort_func = PyDict_GetItemString(kwds, "sort_function"); + if (sort_func != NULL) { + // Must set before deleting as sort_func is a borrowed reference + // and kwds might be the only thing keeping it alive. + if (PyDict_SetItemString(kwds, "cmp", sort_func) == -1) + return NULL; + if (PyDict_DelItemString(kwds, "sort_function") == -1) + return NULL; + } + } + + ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); + if (full_slice == NULL) { + return NULL; + } + ScopedPyObjectPtr list(Subscript(self, full_slice)); + if (list == NULL) { + return NULL; + } + ScopedPyObjectPtr m(PyObject_GetAttrString(list, "sort")); + if (m == NULL) { + return NULL; + } + ScopedPyObjectPtr res(PyObject_Call(m, args, kwds)); + if (res == NULL) { + return NULL; + } + int ret = InternalAssignRepeatedField(self, list); + if (ret < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +static int Init(RepeatedScalarContainer* self, + PyObject* args, + PyObject* kwargs) { + PyObject* py_parent; + PyObject* py_parent_field; + if (!PyArg_UnpackTuple(args, "__init__()", 2, 2, &py_parent, + &py_parent_field)) { + return -1; + } + + if (!PyObject_TypeCheck(py_parent, &CMessage_Type)) { + PyErr_Format(PyExc_TypeError, + "expect %s, but got %s", + CMessage_Type.tp_name, + Py_TYPE(py_parent)->tp_name); + return -1; + } + + if (!PyObject_TypeCheck(py_parent_field, &CFieldDescriptor_Type)) { + PyErr_Format(PyExc_TypeError, + "expect %s, but got %s", + CFieldDescriptor_Type.tp_name, + Py_TYPE(py_parent_field)->tp_name); + return -1; + } + + CMessage* cmessage = reinterpret_cast(py_parent); + CFieldDescriptor* cdescriptor = reinterpret_cast( + py_parent_field); + + if (!FIELD_BELONGS_TO_MESSAGE(cdescriptor->descriptor, cmessage->message)) { + PyErr_SetString( + PyExc_KeyError, "Field does not belong to message!"); + return -1; + } + + self->message = cmessage->message; + self->parent = cmessage; + self->parent_field = cdescriptor; + self->owner = cmessage->owner; + return 0; +} + +// Initializes the underlying Message object of "to" so it becomes a new parent +// repeated scalar, and copies all the values from "from" to it. A child scalar +// container can be released by passing it as both from and to (e.g. making it +// the recipient of the new parent message and copying the values from itself). +static int InitializeAndCopyToParentContainer( + RepeatedScalarContainer* from, + RepeatedScalarContainer* to) { + ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); + if (full_slice == NULL) { + return -1; + } + ScopedPyObjectPtr values(Subscript(from, full_slice)); + if (values == NULL) { + return -1; + } + google::protobuf::Message* new_message = global_message_factory->GetPrototype( + from->message->GetDescriptor())->New(); + to->parent = NULL; + // TODO(anuraag): Document why it's OK to hang on to parent_field, + // even though it's a weak reference. It ought to be enough to + // hold on to the FieldDescriptor only. + to->parent_field = from->parent_field; + to->message = new_message; + to->owner.reset(new_message); + if (InternalAssignRepeatedField(to, values) < 0) { + return -1; + } + return 0; +} + +int Release(RepeatedScalarContainer* self) { + return InitializeAndCopyToParentContainer(self, self); +} + +PyObject* DeepCopy(RepeatedScalarContainer* self, PyObject* arg) { + ScopedPyObjectPtr init_args( + PyTuple_Pack(2, self->parent, self->parent_field)); + PyObject* clone = PyObject_CallObject( + reinterpret_cast(&RepeatedScalarContainer_Type), init_args); + if (clone == NULL) { + return NULL; + } + if (!PyObject_TypeCheck(clone, &RepeatedScalarContainer_Type)) { + Py_DECREF(clone); + return NULL; + } + if (InitializeAndCopyToParentContainer( + self, reinterpret_cast(clone)) < 0) { + Py_DECREF(clone); + return NULL; + } + return clone; +} + +static void Dealloc(RepeatedScalarContainer* self) { + self->owner.reset(); + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + +void SetOwner(RepeatedScalarContainer* self, + const shared_ptr& new_owner) { + self->owner = new_owner; +} + +static PySequenceMethods SqMethods = { + (lenfunc)Len, /* sq_length */ + 0, /* sq_concat */ + 0, /* sq_repeat */ + (ssizeargfunc)Item, /* sq_item */ + 0, /* sq_slice */ + (ssizeobjargproc)AssignItem /* sq_ass_item */ +}; + +static PyMappingMethods MpMethods = { + (lenfunc)Len, /* mp_length */ + (binaryfunc)Subscript, /* mp_subscript */ + (objobjargproc)AssSubscript, /* mp_ass_subscript */ +}; + +static PyMethodDef Methods[] = { + { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, + "Makes a deep copy of the class." }, + { "__reduce__", (PyCFunction)Reduce, METH_NOARGS, + "Outputs picklable representation of the repeated field." }, + { "append", (PyCFunction)Append, METH_O, + "Appends an object to the repeated container." }, + { "extend", (PyCFunction)Extend, METH_O, + "Appends objects to the repeated container." }, + { "insert", (PyCFunction)Insert, METH_VARARGS, + "Appends objects to the repeated container." }, + { "remove", (PyCFunction)Remove, METH_O, + "Removes an object from the repeated container." }, + { "sort", (PyCFunction)Sort, METH_VARARGS | METH_KEYWORDS, + "Sorts the repeated container."}, + { NULL, NULL } +}; + +} // namespace repeated_scalar_container + +PyTypeObject RepeatedScalarContainer_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "google.protobuf.internal." + "cpp._message.RepeatedScalarContainer", // tp_name + sizeof(RepeatedScalarContainer), // tp_basicsize + 0, // tp_itemsize + (destructor)repeated_scalar_container::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + &repeated_scalar_container::SqMethods, // tp_as_sequence + &repeated_scalar_container::MpMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Repeated scalar container", // tp_doc + 0, // tp_traverse + 0, // tp_clear + (richcmpfunc)repeated_scalar_container::RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + repeated_scalar_container::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + (initproc)repeated_scalar_container::Init, // tp_init +}; + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/repeated_scalar_container.h b/python/google/protobuf/pyext/repeated_scalar_container.h new file mode 100644 index 00000000..8a301385 --- /dev/null +++ b/python/google/protobuf/pyext/repeated_scalar_container.h @@ -0,0 +1,112 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__ + +#include + +#include +#ifndef _SHARED_PTR_H +#include +#endif + + +namespace google { +namespace protobuf { + +class Message; + +using internal::shared_ptr; + +namespace python { + +struct CFieldDescriptor; +struct CMessage; + +typedef struct RepeatedScalarContainer { + PyObject_HEAD; + + // This is the top-level C++ Message object that owns the whole + // proto tree. Every Python RepeatedScalarContainer holds a + // reference to it in order to keep it alive as long as there's a + // Python object that references any part of the tree. + shared_ptr owner; + + // Pointer to the C++ Message that contains this container. The + // RepeatedScalarContainer does not own this pointer. + Message* message; + + // Weak reference to a parent CMessage object (i.e. may be NULL.) + // + // Used to make sure all ancestors are also mutable when first + // modifying the container. + CMessage* parent; + + // Weak reference to the parent's descriptor that describes this + // field. Used together with the parent's message when making a + // default message instance mutable. + CFieldDescriptor* parent_field; +} RepeatedScalarContainer; + +extern PyTypeObject RepeatedScalarContainer_Type; + +namespace repeated_scalar_container { + +// Appends the scalar 'item' to the end of the container 'self'. +// +// Returns None if successful; returns NULL and sets an exception if +// unsuccessful. +PyObject* Append(RepeatedScalarContainer* self, PyObject* item); + +// Releases the messages in the container to a new message. +// +// Returns 0 on success, -1 on failure. +int Release(RepeatedScalarContainer* self); + +// Appends all the elements in the input iterator to the container. +// +// Returns None if successful; returns NULL and sets an exception if +// unsuccessful. +PyObject* Extend(RepeatedScalarContainer* self, PyObject* value); + +// Set the owner field of self and any children of self. +void SetOwner(RepeatedScalarContainer* self, + const shared_ptr& new_owner); + +} // namespace repeated_scalar_container +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__ diff --git a/python/google/protobuf/pyext/scoped_pyobject_ptr.h b/python/google/protobuf/pyext/scoped_pyobject_ptr.h new file mode 100644 index 00000000..1b27a894 --- /dev/null +++ b/python/google/protobuf/pyext/scoped_pyobject_ptr.h @@ -0,0 +1,95 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__ + +#include + +namespace google { +class ScopedPyObjectPtr { + public: + // Constructor. Defaults to intializing with NULL. + // There is no way to create an uninitialized ScopedPyObjectPtr. + explicit ScopedPyObjectPtr(PyObject* p = NULL) : ptr_(p) { } + + // Destructor. If there is a PyObject object, delete it. + ~ScopedPyObjectPtr() { + Py_XDECREF(ptr_); + } + + // Reset. Deletes the current owned object, if any. + // Then takes ownership of a new object, if given. + // this->reset(this->get()) works. + PyObject* reset(PyObject* p = NULL) { + if (p != ptr_) { + Py_XDECREF(ptr_); + ptr_ = p; + } + return ptr_; + } + + // Releases ownership of the object. + PyObject* release() { + PyObject* p = ptr_; + ptr_ = NULL; + return p; + } + + operator PyObject*() { return ptr_; } + + PyObject* operator->() const { + assert(ptr_ != NULL); + return ptr_; + } + + PyObject* get() const { return ptr_; } + + Py_ssize_t refcnt() const { return Py_REFCNT(ptr_); } + + void inc() const { Py_INCREF(ptr_); } + + // Comparison operators. + // These return whether a ScopedPyObjectPtr and a raw pointer + // refer to the same object, not just to two different but equal + // objects. + bool operator==(const PyObject* p) const { return ptr_ == p; } + bool operator!=(const PyObject* p) const { return ptr_ != p; } + + private: + PyObject* ptr_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ScopedPyObjectPtr); +}; + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__ diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index 9570fd50..7aac6230 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -57,7 +57,7 @@ _FieldDescriptor = descriptor_mod.FieldDescriptor if api_implementation.Type() == 'cpp': if api_implementation.Version() == 2: - from google.protobuf.internal.cpp import cpp_message + from google.protobuf.pyext import cpp_message _NewMessage = cpp_message.NewMessage _InitMessage = cpp_message.InitMessage else: @@ -91,6 +91,10 @@ class GeneratedProtocolMessageType(type): myproto_instance = MyProtoClass() myproto.foo_field = 23 ... + + The above example will not work for nested types. If you wish to include them, + use reflection.MakeClass() instead of manually instantiating the class in + order to create the appropriate class structure. """ # Must be consistent with the protocol-compiler code in @@ -159,11 +163,43 @@ def ParseMessage(descriptor, byte_str): Returns: Newly created protobuf Message object. """ + result_class = MakeClass(descriptor) + new_msg = result_class() + new_msg.ParseFromString(byte_str) + return new_msg + - class _ResultClass(message.Message): +def MakeClass(descriptor): + """Construct a class object for a protobuf described by descriptor. + + Composite descriptors are handled by defining the new class as a member of the + parent class, recursing as deep as necessary. + This is the dynamic equivalent to: + + class Parent(message.Message): __metaclass__ = GeneratedProtocolMessageType DESCRIPTOR = descriptor + class Child(message.Message): + __metaclass__ = GeneratedProtocolMessageType + DESCRIPTOR = descriptor.nested_types[0] - new_msg = _ResultClass() - new_msg.ParseFromString(byte_str) - return new_msg + Sample usage: + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(proto2_string) + msg_descriptor = descriptor.MakeDescriptor(file_descriptor.message_type[0]) + msg_class = reflection.MakeClass(msg_descriptor) + msg = msg_class() + + Args: + descriptor: A descriptor.Descriptor object describing the protobuf. + Returns: + The Message class object described by the descriptor. + """ + attributes = {} + for name, nested_type in descriptor.nested_types_by_name.items(): + attributes[name] = MakeClass(nested_type) + + attributes[GeneratedProtocolMessageType._DESCRIPTOR_KEY] = descriptor + + return GeneratedProtocolMessageType(str(descriptor.name), (message.Message,), + attributes) diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py new file mode 100644 index 00000000..7466fec5 --- /dev/null +++ b/python/google/protobuf/symbol_database.py @@ -0,0 +1,185 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""A database of Python protocol buffer generated symbols. + +SymbolDatabase makes it easy to create new instances of a registered type, given +only the type's protocol buffer symbol name. Once all symbols are registered, +they can be accessed using either the MessageFactory interface which +SymbolDatabase exposes, or the DescriptorPool interface of the underlying +pool. + +Example usage: + + db = symbol_database.SymbolDatabase() + + # Register symbols of interest, from one or multiple files. + db.RegisterFileDescriptor(my_proto_pb2.DESCRIPTOR) + db.RegisterMessage(my_proto_pb2.MyMessage) + db.RegisterEnumDescriptor(my_proto_pb2.MyEnum.DESCRIPTOR) + + # The database can be used as a MessageFactory, to generate types based on + # their name: + types = db.GetMessages(['my_proto.proto']) + my_message_instance = types['MyMessage']() + + # The database's underlying descriptor pool can be queried, so it's not + # necessary to know a type's filename to be able to generate it: + filename = db.pool.FindFileContainingSymbol('MyMessage') + my_message_instance = db.GetMessages([filename])['MyMessage']() + + # This functionality is also provided directly via a convenience method: + my_message_instance = db.GetSymbol('MyMessage')() +""" + + +from google.protobuf import descriptor_pool + + +class SymbolDatabase(object): + """A database of Python generated symbols. + + SymbolDatabase also models message_factory.MessageFactory. + + The symbol database can be used to keep a global registry of all protocol + buffer types used within a program. + """ + + def __init__(self): + """Constructor.""" + + self._symbols = {} + self._symbols_by_file = {} + self.pool = descriptor_pool.DescriptorPool() + + def RegisterMessage(self, message): + """Registers the given message type in the local database. + + Args: + message: a message.Message, to be registered. + + Returns: + The provided message. + """ + + desc = message.DESCRIPTOR + self._symbols[desc.full_name] = message + if desc.file.name not in self._symbols_by_file: + self._symbols_by_file[desc.file.name] = {} + self._symbols_by_file[desc.file.name][desc.full_name] = message + self.pool.AddDescriptor(desc) + return message + + def RegisterEnumDescriptor(self, enum_descriptor): + """Registers the given enum descriptor in the local database. + + Args: + enum_descriptor: a descriptor.EnumDescriptor. + + Returns: + The provided descriptor. + """ + self.pool.AddEnumDescriptor(enum_descriptor) + return enum_descriptor + + def RegisterFileDescriptor(self, file_descriptor): + """Registers the given file descriptor in the local database. + + Args: + file_descriptor: a descriptor.FileDescriptor. + + Returns: + The provided descriptor. + """ + self.pool.AddFileDescriptor(file_descriptor) + + def GetSymbol(self, symbol): + """Tries to find a symbol in the local database. + + Currently, this method only returns message.Message instances, however, if + may be extended in future to support other symbol types. + + Args: + symbol: A str, a protocol buffer symbol. + + Returns: + A Python class corresponding to the symbol. + + Raises: + KeyError: if the symbol could not be found. + """ + + return self._symbols[symbol] + + def GetPrototype(self, descriptor): + """Builds a proto2 message class based on the passed in descriptor. + + Passing a descriptor with a fully qualified name matching a previous + invocation will cause the same class to be returned. + + Args: + descriptor: The descriptor to build from. + + Returns: + A class describing the passed in descriptor. + """ + + return self.GetSymbol(descriptor.full_name) + + def GetMessages(self, files): + """Gets all the messages from a specified file. + + This will find and resolve dependencies, failing if they are not registered + in the symbol database. + + + Args: + files: The file names to extract messages from. + + Returns: + A dictionary mapping proto names to the message classes. This will include + any dependent messages as well as any messages defined in the same file as + a specified message. + + Raises: + KeyError: if a file could not be found. + """ + + result = {} + for f in files: + result.update(self._symbols_by_file[f]) + return result + +_DEFAULT = SymbolDatabase() + + +def Default(): + """Returns the default SymbolDatabase.""" + return _DEFAULT diff --git a/python/google/protobuf/text_encoding.py b/python/google/protobuf/text_encoding.py new file mode 100644 index 00000000..ed0aabf7 --- /dev/null +++ b/python/google/protobuf/text_encoding.py @@ -0,0 +1,110 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#PY25 compatible for GAE. +# +"""Encoding related utilities.""" + +import re +import sys ##PY25 + +# Lookup table for utf8 +_cescape_utf8_to_str = [chr(i) for i in xrange(0, 256)] +_cescape_utf8_to_str[9] = r'\t' # optional escape +_cescape_utf8_to_str[10] = r'\n' # optional escape +_cescape_utf8_to_str[13] = r'\r' # optional escape +_cescape_utf8_to_str[39] = r"\'" # optional escape + +_cescape_utf8_to_str[34] = r'\"' # necessary escape +_cescape_utf8_to_str[92] = r'\\' # necessary escape + +# Lookup table for non-utf8, with necessary escapes at (o >= 127 or o < 32) +_cescape_byte_to_str = ([r'\%03o' % i for i in xrange(0, 32)] + + [chr(i) for i in xrange(32, 127)] + + [r'\%03o' % i for i in xrange(127, 256)]) +_cescape_byte_to_str[9] = r'\t' # optional escape +_cescape_byte_to_str[10] = r'\n' # optional escape +_cescape_byte_to_str[13] = r'\r' # optional escape +_cescape_byte_to_str[39] = r"\'" # optional escape + +_cescape_byte_to_str[34] = r'\"' # necessary escape +_cescape_byte_to_str[92] = r'\\' # necessary escape + + +def CEscape(text, as_utf8): + """Escape a bytes string for use in an ascii protocol buffer. + + text.encode('string_escape') does not seem to satisfy our needs as it + encodes unprintable characters using two-digit hex escapes whereas our + C++ unescaping function allows hex escapes to be any length. So, + "\0011".encode('string_escape') ends up being "\\x011", which will be + decoded in C++ as a single-character string with char code 0x11. + + Args: + text: A byte string to be escaped + as_utf8: Specifies if result should be returned in UTF-8 encoding + Returns: + Escaped string + """ + # PY3 hack: make Ord work for str and bytes: + # //platforms/networking/data uses unicode here, hence basestring. + Ord = ord if isinstance(text, basestring) else lambda x: x + if as_utf8: + return ''.join(_cescape_utf8_to_str[Ord(c)] for c in text) + return ''.join(_cescape_byte_to_str[Ord(c)] for c in text) + + +_CUNESCAPE_HEX = re.compile(r'(\\+)x([0-9a-fA-F])(?![0-9a-fA-F])') +_cescape_highbit_to_str = ([chr(i) for i in range(0, 127)] + + [r'\%03o' % i for i in range(127, 256)]) + + +def CUnescape(text): + """Unescape a text string with C-style escape sequences to UTF-8 bytes.""" + + def ReplaceHex(m): + # Only replace the match if the number of leading back slashes is odd. i.e. + # the slash itself is not escaped. + if len(m.group(1)) & 1: + return m.group(1) + 'x0' + m.group(2) + return m.group(0) + + # This is required because the 'string_escape' encoding doesn't + # allow single-digit hex escapes (like '\xf'). + result = _CUNESCAPE_HEX.sub(ReplaceHex, text) + + if sys.version_info[0] < 3: ##PY25 +##!PY25 if str is bytes: # PY2 + return result.decode('string_escape') + result = ''.join(_cescape_highbit_to_str[ord(c)] for c in result) + return (result.encode('ascii') # Make it bytes to allow decode. + .decode('unicode_escape') + # Make it bytes again to return the proper type. + .encode('raw_unicode_escape')) diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index 24dd07f2..50f76f22 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#PY25 compatible for GAE. +# +# Copyright 2007 Google Inc. All Rights Reserved. + """Contains routines for printing protocol messages in text format.""" __author__ = 'kenton@google.com (Kenton Varda)' @@ -35,12 +39,12 @@ __author__ = 'kenton@google.com (Kenton Varda)' import cStringIO import re -from collections import deque from google.protobuf.internal import type_checkers from google.protobuf import descriptor +from google.protobuf import text_encoding -__all__ = [ 'MessageToString', 'PrintMessage', 'PrintField', - 'PrintFieldValue', 'Merge' ] +__all__ = ['MessageToString', 'PrintMessage', 'PrintField', + 'PrintFieldValue', 'Merge'] _INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(), @@ -49,15 +53,47 @@ _INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(), type_checkers.Int64ValueChecker()) _FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?', re.IGNORECASE) _FLOAT_NAN = re.compile('nanf?', re.IGNORECASE) +_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT, + descriptor.FieldDescriptor.CPPTYPE_DOUBLE]) + +class Error(Exception): + """Top-level module error for text_format.""" -class ParseError(Exception): + +class ParseError(Error): """Thrown in case of ASCII parsing error.""" -def MessageToString(message, as_utf8=False, as_one_line=False): +def MessageToString(message, as_utf8=False, as_one_line=False, + pointy_brackets=False, use_index_order=False, + float_format=None): + """Convert protobuf message to text format. + + Floating point values can be formatted compactly with 15 digits of + precision (which is the most that IEEE 754 "double" can guarantee) + using float_format='.15g'. + + Args: + message: The protocol buffers message. + as_utf8: Produce text output in UTF8 format. + as_one_line: Don't introduce newlines between fields. + pointy_brackets: If True, use angle brackets instead of curly braces for + nesting. + use_index_order: If True, print fields of a proto message using the order + defined in source code instead of the field number. By default, use the + field number order. + float_format: If set, use this to specify floating point number formatting + (per the "Format Specification Mini-Language"); otherwise, str() is used. + + Returns: + A string of the text formatted protocol buffer message. + """ out = cStringIO.StringIO() - PrintMessage(message, out, as_utf8=as_utf8, as_one_line=as_one_line) + PrintMessage(message, out, as_utf8=as_utf8, as_one_line=as_one_line, + pointy_brackets=pointy_brackets, + use_index_order=use_index_order, + float_format=float_format) result = out.getvalue() out.close() if as_one_line: @@ -65,20 +101,30 @@ def MessageToString(message, as_utf8=False, as_one_line=False): return result -def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False): - for field, value in message.ListFields(): +def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False, + pointy_brackets=False, use_index_order=False, + float_format=None): + fields = message.ListFields() + if use_index_order: + fields.sort(key=lambda x: x[0].index) + for field, value in fields: if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: for element in value: - PrintField(field, element, out, indent, as_utf8, as_one_line) + PrintField(field, element, out, indent, as_utf8, as_one_line, + pointy_brackets=pointy_brackets, + float_format=float_format) else: - PrintField(field, value, out, indent, as_utf8, as_one_line) + PrintField(field, value, out, indent, as_utf8, as_one_line, + pointy_brackets=pointy_brackets, + float_format=float_format) -def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False): +def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False, + pointy_brackets=False, float_format=None): """Print a single field name/value pair. For repeated fields, the value should be a single element.""" - out.write(' ' * indent); + out.write(' ' * indent) if field.is_extension: out.write('[') if (field.containing_type.GetOptions().message_set_wire_format and @@ -100,27 +146,41 @@ def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False): # don't include it. out.write(': ') - PrintFieldValue(field, value, out, indent, as_utf8, as_one_line) + PrintFieldValue(field, value, out, indent, as_utf8, as_one_line, + pointy_brackets=pointy_brackets, + float_format=float_format) if as_one_line: out.write(' ') else: out.write('\n') -def PrintFieldValue(field, value, out, indent=0, - as_utf8=False, as_one_line=False): +def PrintFieldValue(field, value, out, indent=0, as_utf8=False, + as_one_line=False, pointy_brackets=False, + float_format=None): """Print a single field value (not including name). For repeated fields, the value should be a single element.""" + if pointy_brackets: + openb = '<' + closeb = '>' + else: + openb = '{' + closeb = '}' + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: if as_one_line: - out.write(' { ') - PrintMessage(value, out, indent, as_utf8, as_one_line) - out.write('}') + out.write(' %s ' % openb) + PrintMessage(value, out, indent, as_utf8, as_one_line, + pointy_brackets=pointy_brackets, + float_format=float_format) + out.write(closeb) else: - out.write(' {\n') - PrintMessage(value, out, indent + 2, as_utf8, as_one_line) - out.write(' ' * indent + '}') + out.write(' %s\n' % openb) + PrintMessage(value, out, indent + 2, as_utf8, as_one_line, + pointy_brackets=pointy_brackets, + float_format=float_format) + out.write(' ' * indent + closeb) elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: enum_value = field.enum_type.values_by_number.get(value, None) if enum_value is not None: @@ -129,41 +189,125 @@ def PrintFieldValue(field, value, out, indent=0, out.write(str(value)) elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: out.write('\"') - if type(value) is unicode: - out.write(_CEscape(value.encode('utf-8'), as_utf8)) + if isinstance(value, unicode): + out_value = value.encode('utf-8') + else: + out_value = value + if field.type == descriptor.FieldDescriptor.TYPE_BYTES: + # We need to escape non-UTF8 chars in TYPE_BYTES field. + out_as_utf8 = False else: - out.write(_CEscape(value, as_utf8)) + out_as_utf8 = as_utf8 + out.write(text_encoding.CEscape(out_value, out_as_utf8)) out.write('\"') elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: if value: - out.write("true") + out.write('true') else: - out.write("false") + out.write('false') + elif field.cpp_type in _FLOAT_TYPES and float_format is not None: + out.write('{1:{0}}'.format(float_format, value)) else: out.write(str(value)) +def _ParseOrMerge(lines, message, allow_multiple_scalars): + """Converts an ASCII representation of a protocol message into a message. + + Args: + lines: Lines of a message's ASCII representation. + message: A protocol buffer message to merge into. + allow_multiple_scalars: Determines if repeated values for a non-repeated + field are permitted, e.g., the string "foo: 1 foo: 2" for a + required/optional field named "foo". + + Raises: + ParseError: On ASCII parsing problems. + """ + tokenizer = _Tokenizer(lines) + while not tokenizer.AtEnd(): + _MergeField(tokenizer, message, allow_multiple_scalars) + + +def Parse(text, message): + """Parses an ASCII representation of a protocol message into a message. + + Args: + text: Message ASCII representation. + message: A protocol buffer message to merge into. + + Returns: + The same message passed as argument. + + Raises: + ParseError: On ASCII parsing problems. + """ + if not isinstance(text, str): text = text.decode('utf-8') + return ParseLines(text.split('\n'), message) + + def Merge(text, message): - """Merges an ASCII representation of a protocol message into a message. + """Parses an ASCII representation of a protocol message into a message. + + Like Parse(), but allows repeated values for a non-repeated field, and uses + the last one. Args: text: Message ASCII representation. message: A protocol buffer message to merge into. + Returns: + The same message passed as argument. + Raises: ParseError: On ASCII parsing problems. """ - tokenizer = _Tokenizer(text) - while not tokenizer.AtEnd(): - _MergeField(tokenizer, message) + return MergeLines(text.split('\n'), message) -def _MergeField(tokenizer, message): +def ParseLines(lines, message): + """Parses an ASCII representation of a protocol message into a message. + + Args: + lines: An iterable of lines of a message's ASCII representation. + message: A protocol buffer message to merge into. + + Returns: + The same message passed as argument. + + Raises: + ParseError: On ASCII parsing problems. + """ + _ParseOrMerge(lines, message, False) + return message + + +def MergeLines(lines, message): + """Parses an ASCII representation of a protocol message into a message. + + Args: + lines: An iterable of lines of a message's ASCII representation. + message: A protocol buffer message to merge into. + + Returns: + The same message passed as argument. + + Raises: + ParseError: On ASCII parsing problems. + """ + _ParseOrMerge(lines, message, True) + return message + + +def _MergeField(tokenizer, message, allow_multiple_scalars): """Merges a single protocol message field into a message. Args: tokenizer: A tokenizer to parse the field name and values. message: A protocol message to record the data. + allow_multiple_scalars: Determines if repeated values for a non-repeated + field are permitted, e.g., the string "foo: 1 foo: 2" for a + required/optional field named "foo". Raises: ParseError: In case of ASCII parsing problems. @@ -179,7 +323,9 @@ def _MergeField(tokenizer, message): raise tokenizer.ParseErrorPreviousToken( 'Message type "%s" does not have extensions.' % message_descriptor.full_name) + # pylint: disable=protected-access field = message.Extensions._FindExtensionByName(name) + # pylint: enable=protected-access if not field: raise tokenizer.ParseErrorPreviousToken( 'Extension "%s" not registered.' % name) @@ -233,18 +379,26 @@ def _MergeField(tokenizer, message): while not tokenizer.TryConsume(end_token): if tokenizer.AtEnd(): raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token)) - _MergeField(tokenizer, sub_message) + _MergeField(tokenizer, sub_message, allow_multiple_scalars) else: - _MergeScalarField(tokenizer, message, field) + _MergeScalarField(tokenizer, message, field, allow_multiple_scalars) + # For historical reasons, fields may optionally be separated by commas or + # semicolons. + if not tokenizer.TryConsume(','): + tokenizer.TryConsume(';') -def _MergeScalarField(tokenizer, message, field): + +def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars): """Merges a single protocol message scalar field into a message. Args: tokenizer: A tokenizer to parse the field value. message: A protocol message to record the data. field: The descriptor of the field to be merged. + allow_multiple_scalars: Determines if repeated values for a non-repeated + field are permitted, e.g., the string "foo: 1 foo: 2" for a + required/optional field named "foo". Raises: ParseError: In case of ASCII parsing problems. @@ -288,9 +442,19 @@ def _MergeScalarField(tokenizer, message, field): getattr(message, field.name).append(value) else: if field.is_extension: - message.Extensions[field] = value + if not allow_multiple_scalars and message.HasExtension(field): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" extensions.' % + (message.DESCRIPTOR.full_name, field.full_name)) + else: + message.Extensions[field] = value else: - setattr(message, field.name, value) + if not allow_multiple_scalars and message.HasField(field.name): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" fields.' % + (message.DESCRIPTOR.full_name, field.name)) + else: + setattr(message, field.name, value) class _Tokenizer(object): @@ -308,20 +472,19 @@ class _Tokenizer(object): '[0-9+-][0-9a-zA-Z_.+-]*|' # a number '\"([^\"\n\\\\]|\\\\.)*(\"|\\\\?$)|' # a double-quoted string '\'([^\'\n\\\\]|\\\\.)*(\'|\\\\?$)') # a single-quoted string - _IDENTIFIER = re.compile('\w+') - - def __init__(self, text_message): - self._text_message = text_message + _IDENTIFIER = re.compile(r'\w+') + def __init__(self, lines): self._position = 0 self._line = -1 self._column = 0 self._token_start = None self.token = '' - self._lines = deque(text_message.split('\n')) + self._lines = iter(lines) self._current_line = '' self._previous_line = 0 self._previous_column = 0 + self._more_lines = True self._SkipWhitespace() self.NextToken() @@ -331,16 +494,19 @@ class _Tokenizer(object): Returns: True iff the end was reached. """ - return self.token == '' + return not self.token def _PopLine(self): while len(self._current_line) <= self._column: - if not self._lines: + try: + self._current_line = self._lines.next() + except StopIteration: self._current_line = '' + self._more_lines = False return - self._line += 1 - self._column = 0 - self._current_line = self._lines.popleft() + else: + self._line += 1 + self._column = 0 def _SkipWhitespace(self): while True: @@ -497,9 +663,9 @@ class _Tokenizer(object): Raises: ParseError: If a string value couldn't be consumed. """ - bytes = self.ConsumeByteString() + the_bytes = self.ConsumeByteString() try: - return unicode(bytes, 'utf-8') + return unicode(the_bytes, 'utf-8') except UnicodeDecodeError, e: raise self._StringParseError(e) @@ -512,10 +678,11 @@ class _Tokenizer(object): Raises: ParseError: If a byte array value couldn't be consumed. """ - list = [self._ConsumeSingleByteString()] - while len(self.token) > 0 and self.token[0] in ('\'', '"'): - list.append(self._ConsumeSingleByteString()) - return "".join(list) + the_list = [self._ConsumeSingleByteString()] + while self.token and self.token[0] in ('\'', '"'): + the_list.append(self._ConsumeSingleByteString()) + return ''.encode('latin1').join(the_list) ##PY25 +##!PY25 return b''.join(the_list) def _ConsumeSingleByteString(self): """Consume one token of a string literal. @@ -532,7 +699,7 @@ class _Tokenizer(object): raise self._ParseError('String missing ending quote.') try: - result = _CUnescape(text[1:-1]) + result = text_encoding.CUnescape(text[1:-1]) except ValueError, e: raise self._ParseError(str(e)) self.NextToken() @@ -574,7 +741,7 @@ class _Tokenizer(object): self._column += len(self.token) self._SkipWhitespace() - if not self._lines and len(self._current_line) <= self._column: + if not self._more_lines: self.token = '' return @@ -586,45 +753,6 @@ class _Tokenizer(object): self.token = self._current_line[self._column] -# text.encode('string_escape') does not seem to satisfy our needs as it -# encodes unprintable characters using two-digit hex escapes whereas our -# C++ unescaping function allows hex escapes to be any length. So, -# "\0011".encode('string_escape') ends up being "\\x011", which will be -# decoded in C++ as a single-character string with char code 0x11. -def _CEscape(text, as_utf8): - def escape(c): - o = ord(c) - if o == 10: return r"\n" # optional escape - if o == 13: return r"\r" # optional escape - if o == 9: return r"\t" # optional escape - if o == 39: return r"\'" # optional escape - - if o == 34: return r'\"' # necessary escape - if o == 92: return r"\\" # necessary escape - - # necessary escapes - if not as_utf8 and (o >= 127 or o < 32): return "\\%03o" % o - return c - return "".join([escape(c) for c in text]) - - -_CUNESCAPE_HEX = re.compile(r'(\\+)x([0-9a-fA-F])(?![0-9a-fA-F])') - - -def _CUnescape(text): - def ReplaceHex(m): - # Only replace the match if the number of leading back slashes is odd. i.e. - # the slash itself is not escaped. - if len(m.group(1)) & 1: - return m.group(1) + 'x0' + m.group(2) - return m.group(0) - - # This is required because the 'string_escape' encoding doesn't - # allow single-digit hex escapes (like '\xf'). - result = _CUNESCAPE_HEX.sub(ReplaceHex, text) - return result.decode('string_escape') - - def ParseInteger(text, is_signed=False, is_long=False): """Parses an integer. @@ -641,7 +769,13 @@ def ParseInteger(text, is_signed=False, is_long=False): """ # Do the actual parsing. Exception handling is propagated to caller. try: - result = int(text, 0) + # We force 32-bit values to int and 64-bit values to long to make + # alternate implementations where the distinction is more significant + # (e.g. the C++ implementation) simpler. + if is_long: + result = long(text, 0) + else: + result = int(text, 0) except ValueError: raise ValueError('Couldn\'t parse integer: %s' % text) diff --git a/python/setup.py b/python/setup.py index 24a446a9..0e95adad 100755 --- a/python/setup.py +++ b/python/setup.py @@ -8,17 +8,14 @@ import subprocess # We must use setuptools, not distutils, because we need to use the # namespace_packages option for the "google" package. try: - from setuptools import setup, Extension + from ez_setup import use_setuptools + use_setuptools() + from setuptools import setup, Extension, __version__ except ImportError: - try: - from ez_setup import use_setuptools - use_setuptools() - from setuptools import setup, Extension - except ImportError: - sys.stderr.write( - "Could not import setuptools; make sure you have setuptools or " - "ez_setup installed.\n") - raise + sys.stderr.write( + "Could not import setuptools; make sure you have setuptools or " + "ez_setup installed.\n") + raise from distutils.command.clean import clean as _clean from distutils.command.build_py import build_py as _build_py from distutils.spawn import find_executable @@ -49,7 +46,7 @@ def generate_proto(source): if (not os.path.exists(output) or (os.path.exists(source) and os.path.getmtime(source) > os.path.getmtime(output))): - print "Generating %s..." % output + print ("Generating %s..." % output) if not os.path.exists(source): sys.stderr.write("Can't find required file: %s\n" % source) @@ -72,7 +69,6 @@ def GenerateUnittestProtos(): generate_proto("../src/google/protobuf/unittest_import_public.proto") generate_proto("../src/google/protobuf/unittest_mset.proto") generate_proto("../src/google/protobuf/unittest_no_generic_services.proto") - generate_proto("google/protobuf/internal/compatibility_mode_test.proto") generate_proto("google/protobuf/internal/descriptor_pool_test1.proto") generate_proto("google/protobuf/internal/descriptor_pool_test2.proto") generate_proto("google/protobuf/internal/test_bad_identifiers.proto") @@ -82,75 +78,7 @@ def GenerateUnittestProtos(): generate_proto("google/protobuf/internal/more_messages.proto") generate_proto("google/protobuf/internal/factory_test1.proto") generate_proto("google/protobuf/internal/factory_test2.proto") - -def MakeTestSuite(): - # This is apparently needed on some systems to make sure that the tests - # work even if a previous version is already installed. - if 'google' in sys.modules: - del sys.modules['google'] - GenerateUnittestProtos() - - import unittest - import google.protobuf.internal.generator_test as generator_test - import google.protobuf.internal.descriptor_test as descriptor_test - import google.protobuf.internal.reflection_test as reflection_test - import google.protobuf.internal.service_reflection_test \ - as service_reflection_test - import google.protobuf.internal.text_format_test as text_format_test - import google.protobuf.internal.wire_format_test as wire_format_test - import google.protobuf.internal.unknown_fields_test as unknown_fields_test - import google.protobuf.internal.descriptor_database_test \ - as descriptor_database_test - import google.protobuf.internal.descriptor_pool_test as descriptor_pool_test - import google.protobuf.internal.message_factory_test as message_factory_test - import google.protobuf.internal.message_cpp_test as message_cpp_test - import google.protobuf.internal.reflection_cpp_generated_test \ - as reflection_cpp_generated_test - import google.protobuf.internal.api_implementation_default_test \ - as api_implementation_default_test - import google.protobuf.internal.descriptor_cpp2_test as descriptor_cpp2_test - import google.protobuf.internal.descriptor_python_test \ - as descriptor_python_test - import google.protobuf.internal.message_factory_cpp2_test \ - as message_factory_cpp2_test - import google.protobuf.internal.message_factory_cpp_test \ - as message_factory_cpp_test - import google.protobuf.internal.message_factory_python_test \ - as message_factory_python_test - import google.protobuf.internal.message_python_test as message_python_test - import google.protobuf.internal.reflection_cpp2_generated_test \ - as reflection_cpp2_generated_test - import google.protobuf.internal.symbol_database_test as symbol_database_test - import google.protobuf.internal.text_encoding_test as text_encoding_test - - loader = unittest.defaultTestLoader - suite = unittest.TestSuite() - for test in [ generator_test, - descriptor_test, - reflection_test, - service_reflection_test, - text_format_test, - wire_format_test, - unknown_fields_test, - descriptor_pool_test, - message_factory_test, - message_cpp_test, - reflection_cpp_generated_test, - api_implementation_default_test, - descriptor_cpp2_test, - descriptor_python_test, - message_factory_cpp2_test, - message_factory_cpp_test, - message_factory_python_test, - message_python_test, - reflection_cpp2_generated_test, - symbol_database_test, - text_encoding_test ]: - - suite.addTest(loader.loadTestsFromModule(test)) - - return suite - + generate_proto("google/protobuf/pyext/python.proto") class clean(_clean): def run(self): @@ -170,31 +98,52 @@ class build_py(_build_py): # Generate necessary .proto file if it doesn't exist. generate_proto("../src/google/protobuf/descriptor.proto") generate_proto("../src/google/protobuf/compiler/plugin.proto") - - # Make sure google.protobuf.compiler is a valid package. - open('google/protobuf/compiler/__init__.py', 'a').close() + GenerateUnittestProtos() + + # Make sure google.protobuf/** are valid packages. + for path in ['', 'internal/', 'compiler/', 'pyext/']: + try: + open('google/protobuf/%s__init__.py' % path, 'a').close() + except EnvironmentError: + pass # _build_py is an old-style class, so super() doesn't work. _build_py.run(self) + # TODO(mrovner): Subclass to run 2to3 on some files only. + # Tracing what https://wiki.python.org/moin/PortingPythonToPy3k's "Approach 2" + # section on how to get 2to3 to run on source files during install under + # Python 3. This class seems like a good place to put logic that calls + # python3's distutils.util.run_2to3 on the subset of the files we have in our + # release that are subject to conversion. + # See code reference in previous code review. -if __name__ == '__main__': - ext_module_list = [] +if __name__ == '__main__': + print(__version__) # C++ implementation extension - if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python") == "cpp": - print "Using EXPERIMENTAL C++ Implmenetation." - ext_module_list.append(Extension( - "google.protobuf.internal._net_proto2___python", - [ "google/protobuf/pyext/python_descriptor.cc", - "google/protobuf/pyext/python_protobuf.cc", - "google/protobuf/pyext/python-proto2.cc" ], - include_dirs = [ "." ], - libraries = [ "protobuf" ])) + nocpp = '--nocpp_implementation' + if nocpp in sys.argv: + ext_module_list = [] + sys.argv.remove(nocpp) + else: + nocpp = False + ext_module_list = [Extension( + "google.protobuf.pyext._message", + [ "google/protobuf/pyext/descriptor.cc", + "google/protobuf/pyext/message.cc", + "google/protobuf/pyext/extension_dict.cc", + "google/protobuf/pyext/repeated_scalar_container.cc", + "google/protobuf/pyext/repeated_composite_container.cc" ], + define_macros=[('GOOGLE_PROTOBUF_HAS_ONEOF', '1')], + include_dirs = [ ".", "../src" ], + libraries = [ "protobuf" ], + library_dirs = [ '../src/.libs' ], + )] setup(name = 'protobuf', - version = '2.5.1-pre', + version = '2.6-pre', packages = [ 'google' ], namespace_packages = [ 'google' ], - test_suite = 'setup.MakeTestSuite', + google_test_dir = "google/protobuf/internal", # Must list modules explicitly so that we don't install tests. py_modules = [ 'google.protobuf.internal.api_implementation', @@ -217,9 +166,12 @@ if __name__ == '__main__': 'google.protobuf.reflection', 'google.protobuf.service', 'google.protobuf.service_reflection', - 'google.protobuf.text_format' ], + 'google.protobuf.symbol_database', + 'google.protobuf.text_encoding', + 'google.protobuf.text_format'], cmdclass = { 'clean': clean, 'build_py': build_py }, install_requires = ['setuptools'], + setup_requires = ['google-apputils'], ext_modules = ext_module_list, url = 'http://code.google.com/p/protobuf/', maintainer = maintainer_email, @@ -228,4 +180,5 @@ if __name__ == '__main__': description = 'Protocol Buffers', long_description = "Protocol Buffers are Google's data interchange format.", + use_2to3=True, ) diff --git a/src/Makefile.am b/src/Makefile.am index ada36ad4..3a164ff1 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -101,6 +101,7 @@ libprotobuf_lite_la_SOURCES = \ google/protobuf/stubs/once.cc \ google/protobuf/stubs/hash.h \ google/protobuf/stubs/map_util.h \ + google/protobuf/stubs/shared_ptr.h \ google/protobuf/stubs/stl_util.h \ google/protobuf/stubs/stringprintf.cc \ google/protobuf/stubs/stringprintf.h \ @@ -246,7 +247,10 @@ EXTRA_DIST = \ google/protobuf/testdata/golden_message_oneof_implemented \ google/protobuf/testdata/golden_packed_fields_message \ google/protobuf/testdata/text_format_unittest_data_oneof_implemented.txt \ + google/protobuf/testdata/text_format_unittest_data_pointy.txt \ + google/protobuf/testdata/text_format_unittest_data_pointy_oneof_implemented.txt \ google/protobuf/testdata/text_format_unittest_extensions_data.txt \ + google/protobuf/testdata/text_format_unittest_extensions_data_pointy.txt \ google/protobuf/package_info.h \ google/protobuf/io/package_info.h \ google/protobuf/compiler/package_info.h \ diff --git a/src/google/protobuf/compiler/python/python_generator.cc b/src/google/protobuf/compiler/python/python_generator.cc index 067d856b..113873a2 100644 --- a/src/google/protobuf/compiler/python/python_generator.cc +++ b/src/google/protobuf/compiler/python/python_generator.cc @@ -642,7 +642,22 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const { "end", SimpleItoa(range->end)); } printer_->Print("],\n"); - + printer_->Print("oneofs=[\n"); + printer_->Indent(); + for (int i = 0; i < message_descriptor.oneof_decl_count(); ++i) { + const OneofDescriptor* desc = message_descriptor.oneof_decl(i); + map m; + m["name"] = desc->name(); + m["full_name"] = desc->full_name(); + m["index"] = SimpleItoa(desc->index()); + printer_->Print( + m, + "_descriptor.OneofDescriptor(\n" + " name='$name$', full_name='$full_name$',\n" + " index=$index$, containing_type=None, fields=[]),\n"); + } + printer_->Outdent(); + printer_->Print("],\n"); // Serialization of proto DescriptorProto edp; PrintSerializedPbInterval(message_descriptor, edp); @@ -743,6 +758,23 @@ void Generator::FixForeignFieldsInDescriptor( const EnumDescriptor& enum_descriptor = *descriptor.enum_type(i); FixContainingTypeInDescriptor(enum_descriptor, &descriptor); } + for (int i = 0; i < descriptor.oneof_decl_count(); ++i) { + map m; + const OneofDescriptor* oneof = descriptor.oneof_decl(i); + m["descriptor_name"] = ModuleLevelDescriptorName(descriptor); + m["oneof_name"] = oneof->name(); + for (int j = 0; j < oneof->field_count(); ++j) { + m["field_name"] = oneof->field(j)->name(); + printer_->Print( + m, + "$descriptor_name$.oneofs_by_name['$oneof_name$'].fields.append(\n" + " $descriptor_name$.fields_by_name['$field_name$'])\n"); + printer_->Print( + m, + "$descriptor_name$.fields_by_name['$field_name$'].containing_oneof = " + "$descriptor_name$.oneofs_by_name['$oneof_name$']\n"); + } + } } void Generator::AddMessageToFileDescriptor(const Descriptor& descriptor) const { diff --git a/src/google/protobuf/testdata/text_format_unittest_data_pointy_oneof_implemented.txt b/src/google/protobuf/testdata/text_format_unittest_data_pointy_oneof_implemented.txt new file mode 100644 index 00000000..95109f62 --- /dev/null +++ b/src/google/protobuf/testdata/text_format_unittest_data_pointy_oneof_implemented.txt @@ -0,0 +1,129 @@ +optional_int32: 101 +optional_int64: 102 +optional_uint32: 103 +optional_uint64: 104 +optional_sint32: 105 +optional_sint64: 106 +optional_fixed32: 107 +optional_fixed64: 108 +optional_sfixed32: 109 +optional_sfixed64: 110 +optional_float: 111 +optional_double: 112 +optional_bool: true +optional_string: "115" +optional_bytes: "116" +OptionalGroup < + a: 117 +> +optional_nested_message < + bb: 118 +> +optional_foreign_message < + c: 119 +> +optional_import_message < + d: 120 +> +optional_nested_enum: BAZ +optional_foreign_enum: FOREIGN_BAZ +optional_import_enum: IMPORT_BAZ +optional_string_piece: "124" +optional_cord: "125" +optional_public_import_message < + e: 126 +> +optional_lazy_message < + bb: 127 +> +repeated_int32: 201 +repeated_int32: 301 +repeated_int64: 202 +repeated_int64: 302 +repeated_uint32: 203 +repeated_uint32: 303 +repeated_uint64: 204 +repeated_uint64: 304 +repeated_sint32: 205 +repeated_sint32: 305 +repeated_sint64: 206 +repeated_sint64: 306 +repeated_fixed32: 207 +repeated_fixed32: 307 +repeated_fixed64: 208 +repeated_fixed64: 308 +repeated_sfixed32: 209 +repeated_sfixed32: 309 +repeated_sfixed64: 210 +repeated_sfixed64: 310 +repeated_float: 211 +repeated_float: 311 +repeated_double: 212 +repeated_double: 312 +repeated_bool: true +repeated_bool: false +repeated_string: "215" +repeated_string: "315" +repeated_bytes: "216" +repeated_bytes: "316" +RepeatedGroup < + a: 217 +> +RepeatedGroup < + a: 317 +> +repeated_nested_message < + bb: 218 +> +repeated_nested_message < + bb: 318 +> +repeated_foreign_message < + c: 219 +> +repeated_foreign_message < + c: 319 +> +repeated_import_message < + d: 220 +> +repeated_import_message < + d: 320 +> +repeated_nested_enum: BAR +repeated_nested_enum: BAZ +repeated_foreign_enum: FOREIGN_BAR +repeated_foreign_enum: FOREIGN_BAZ +repeated_import_enum: IMPORT_BAR +repeated_import_enum: IMPORT_BAZ +repeated_string_piece: "224" +repeated_string_piece: "324" +repeated_cord: "225" +repeated_cord: "325" +repeated_lazy_message < + bb: 227 +> +repeated_lazy_message < + bb: 327 +> +default_int32: 401 +default_int64: 402 +default_uint32: 403 +default_uint64: 404 +default_sint32: 405 +default_sint64: 406 +default_fixed32: 407 +default_fixed64: 408 +default_sfixed32: 409 +default_sfixed64: 410 +default_float: 411 +default_double: 412 +default_bool: false +default_string: "415" +default_bytes: "416" +default_nested_enum: FOO +default_foreign_enum: FOREIGN_FOO +default_import_enum: IMPORT_FOO +default_string_piece: "424" +default_cord: "425" +oneof_bytes: "604" -- cgit v1.2.3