From bde4a3254a7de58911941b0fbf38e9dd992de973 Mon Sep 17 00:00:00 2001 From: "jieluo@google.com" Date: Tue, 12 Aug 2014 21:10:30 +0000 Subject: down integrate python opensource to svn --- python/google/protobuf/descriptor_pool.py | 394 +++++++++++++++++++----------- 1 file changed, 255 insertions(+), 139 deletions(-) (limited to 'python/google/protobuf/descriptor_pool.py') diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 8f1f4457..372f458f 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -49,13 +49,34 @@ Below is a straightforward example on how to use this class: The message descriptor can be used in conjunction with the message_factory module in order to create a protocol buffer class that can be encoded and decoded. + +If you want to get a Python class for the specified proto, use the +helper functions inside google.protobuf.message_factory +directly instead of this class. """ __author__ = 'matthewtoia@google.com (Matt Toia)' -from google.protobuf import descriptor_pb2 +import sys + from google.protobuf import descriptor from google.protobuf import descriptor_database +from google.protobuf import text_encoding + + +def _NormalizeFullyQualifiedName(name): + """Remove leading period from fully-qualified type name. + + Due to b/13860351 in descriptor_database.py, types in the root namespace are + generated with a leading period. This function removes that prefix. + + Args: + name: A str, the fully-qualified symbol name. + + Returns: + A str, the normalized fully-qualified symbol name. + """ + return name.lstrip('.') class DescriptorPool(object): @@ -89,6 +110,51 @@ class DescriptorPool(object): self._internal_db.Add(file_desc_proto) + def AddDescriptor(self, desc): + """Adds a Descriptor to the pool, non-recursively. + + If the Descriptor contains nested messages or enums, the caller must + explicitly register them. 
This method also registers the FileDescriptor + associated with the message. + + Args: + desc: A Descriptor. + """ + if not isinstance(desc, descriptor.Descriptor): + raise TypeError('Expected instance of descriptor.Descriptor.') + + self._descriptors[desc.full_name] = desc + self.AddFileDescriptor(desc.file) + + def AddEnumDescriptor(self, enum_desc): + """Adds an EnumDescriptor to the pool. + + This method also registers the FileDescriptor associated with the message. + + Args: + enum_desc: An EnumDescriptor. + """ + + if not isinstance(enum_desc, descriptor.EnumDescriptor): + raise TypeError('Expected instance of descriptor.EnumDescriptor.') + + self._enum_descriptors[enum_desc.full_name] = enum_desc + self.AddFileDescriptor(enum_desc.file) + + def AddFileDescriptor(self, file_desc): + """Adds a FileDescriptor to the pool, non-recursively. + + If the FileDescriptor contains messages or enums, the caller must explicitly + register them. + + Args: + file_desc: A FileDescriptor. + """ + + if not isinstance(file_desc, descriptor.FileDescriptor): + raise TypeError('Expected instance of descriptor.FileDescriptor.') + self._file_descriptors[file_desc.name] = file_desc + def FindFileByName(self, file_name): """Gets a FileDescriptor by file name. @@ -102,9 +168,15 @@ class DescriptorPool(object): KeyError: if the file can not be found in the pool. """ + try: + return self._file_descriptors[file_name] + except KeyError: + pass + try: file_proto = self._internal_db.FindFileByName(file_name) - except KeyError as error: + except KeyError: + _, error, _ = sys.exc_info() #PY25 compatible for GAE. if self._descriptor_db: file_proto = self._descriptor_db.FindFileByName(file_name) else: @@ -126,9 +198,21 @@ class DescriptorPool(object): KeyError: if the file can not be found in the pool. 
""" + symbol = _NormalizeFullyQualifiedName(symbol) + try: + return self._descriptors[symbol].file + except KeyError: + pass + + try: + return self._enum_descriptors[symbol].file + except KeyError: + pass + try: file_proto = self._internal_db.FindFileContainingSymbol(symbol) - except KeyError as error: + except KeyError: + _, error, _ = sys.exc_info() #PY25 compatible for GAE. if self._descriptor_db: file_proto = self._descriptor_db.FindFileContainingSymbol(symbol) else: @@ -147,7 +231,7 @@ class DescriptorPool(object): The descriptor for the named type. """ - full_name = full_name.lstrip('.') # fix inconsistent qualified name formats + full_name = _NormalizeFullyQualifiedName(full_name) if full_name not in self._descriptors: self.FindFileContainingSymbol(full_name) return self._descriptors[full_name] @@ -162,7 +246,7 @@ class DescriptorPool(object): The enum descriptor for the named type. """ - full_name = full_name.lstrip('.') # fix inconsistent qualified name formats + full_name = _NormalizeFullyQualifiedName(full_name) if full_name not in self._enum_descriptors: self.FindFileContainingSymbol(full_name) return self._enum_descriptors[full_name] @@ -181,46 +265,56 @@ class DescriptorPool(object): """ if file_proto.name not in self._file_descriptors: + built_deps = list(self._GetDeps(file_proto.dependency)) + direct_deps = [self.FindFileByName(n) for n in file_proto.dependency] + file_descriptor = descriptor.FileDescriptor( name=file_proto.name, package=file_proto.package, options=file_proto.options, - serialized_pb=file_proto.SerializeToString()) + serialized_pb=file_proto.SerializeToString(), + dependencies=direct_deps) scope = {} - dependencies = list(self._GetDeps(file_proto)) - for dependency in dependencies: - dep_desc = self.FindFileByName(dependency.name) - dep_proto = descriptor_pb2.FileDescriptorProto.FromString( - dep_desc.serialized_pb) - package = '.' + dep_proto.package - package_prefix = package + '.' 
- - def _strip_package(symbol): - if symbol.startswith(package_prefix): - return symbol[len(package_prefix):] - return symbol - - symbols = list(self._ExtractSymbols(dep_proto.message_type, package)) - scope.update(symbols) - scope.update((_strip_package(k), v) for k, v in symbols) - - symbols = list(self._ExtractEnums(dep_proto.enum_type, package)) - scope.update(symbols) - scope.update((_strip_package(k), v) for k, v in symbols) + # This loop extracts all the message and enum types from all the + # dependencies of the file_proto. This is necessary to create the + # scope of available message types when defining the passed in + # file proto. + for dependency in built_deps: + scope.update(self._ExtractSymbols( + dependency.message_types_by_name.values())) + scope.update((_PrefixWithDot(enum.full_name), enum) + for enum in dependency.enum_types_by_name.values()) for message_type in file_proto.message_type: message_desc = self._ConvertMessageDescriptor( message_type, file_proto.package, file_descriptor, scope) file_descriptor.message_types_by_name[message_desc.name] = message_desc + for enum_type in file_proto.enum_type: - self._ConvertEnumDescriptor(enum_type, file_proto.package, - file_descriptor, None, scope) - for desc_proto in self._ExtractMessages(file_proto.message_type): - self._SetFieldTypes(desc_proto, scope) + file_descriptor.enum_types_by_name[enum_type.name] = ( + self._ConvertEnumDescriptor(enum_type, file_proto.package, + file_descriptor, None, scope)) + + for index, extension_proto in enumerate(file_proto.extension): + extension_desc = self.MakeFieldDescriptor( + extension_proto, file_proto.package, index, is_extension=True) + extension_desc.containing_type = self._GetTypeFromScope( + file_descriptor.package, extension_proto.extendee, scope) + self.SetFieldType(extension_proto, extension_desc, + file_descriptor.package, scope) + file_descriptor.extensions_by_name[extension_desc.name] = extension_desc + + for desc_proto in file_proto.message_type: + 
self.SetAllFieldTypes(file_proto.package, desc_proto, scope) + + if file_proto.package: + desc_proto_prefix = _PrefixWithDot(file_proto.package) + else: + desc_proto_prefix = '' for desc_proto in file_proto.message_type: - desc = scope[desc_proto.name] + desc = self._GetTypeFromScope(desc_proto_prefix, desc_proto.name, scope) file_descriptor.message_types_by_name[desc_proto.name] = desc self.Add(file_proto) self._file_descriptors[file_proto.name] = file_descriptor @@ -260,10 +354,15 @@ class DescriptorPool(object): enums = [ self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope) for enum in desc_proto.enum_type] - fields = [self._MakeFieldDescriptor(field, desc_name, index) + fields = [self.MakeFieldDescriptor(field, desc_name, index) for index, field in enumerate(desc_proto.field)] - extensions = [self._MakeFieldDescriptor(extension, desc_name, True) - for index, extension in enumerate(desc_proto.extension)] + extensions = [ + self.MakeFieldDescriptor(extension, desc_name, index, is_extension=True) + for index, extension in enumerate(desc_proto.extension)] + oneofs = [ + descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)), + index, None, []) + for index, desc in enumerate(desc_proto.oneof_decl)] extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range] if extension_ranges: is_extendable = True @@ -275,6 +374,7 @@ class DescriptorPool(object): filename=file_name, containing_type=None, fields=fields, + oneofs=oneofs, nested_types=nested, enum_types=enums, extensions=extensions, @@ -288,8 +388,13 @@ class DescriptorPool(object): nested.containing_type = desc for enum in desc.enum_types: enum.containing_type = desc - scope[desc_proto.name] = desc - scope['.' 
+ desc_name] = desc + for field_index, field_desc in enumerate(desc_proto.field): + if field_desc.HasField('oneof_index'): + oneof_index = field_desc.oneof_index + oneofs[oneof_index].fields.append(fields[field_index]) + fields[field_index].containing_oneof = oneofs[oneof_index] + + scope[_PrefixWithDot(desc_name)] = desc self._descriptors[desc_name] = desc return desc @@ -327,13 +432,12 @@ class DescriptorPool(object): values=values, containing_type=containing_type, options=enum_proto.options) - scope[enum_proto.name] = desc scope['.%s' % enum_name] = desc self._enum_descriptors[enum_name] = desc return desc - def _MakeFieldDescriptor(self, field_proto, message_name, index, - is_extension=False): + def MakeFieldDescriptor(self, field_proto, message_name, index, + is_extension=False): """Creates a field descriptor from a FieldDescriptorProto. For message and enum type fields, this method will do a look up @@ -374,65 +478,93 @@ class DescriptorPool(object): extension_scope=None, options=field_proto.options) - def _SetFieldTypes(self, desc_proto, scope): - """Sets the field's type, cpp_type, message_type and enum_type. + def SetAllFieldTypes(self, package, desc_proto, scope): + """Sets all the descriptor's fields' types. + + This method also sets the containing types on any extensions. Args: + package: The current package of desc_proto. desc_proto: The message descriptor to update. scope: Enclosing scope of available types. """ - desc = scope[desc_proto.name] - for field_proto, field_desc in zip(desc_proto.field, desc.fields): - if field_proto.type_name: - type_name = field_proto.type_name - if type_name not in scope: - type_name = '.' 
+ type_name - desc = scope[type_name] - else: - desc = None + package = _PrefixWithDot(package) - if not field_proto.HasField('type'): - if isinstance(desc, descriptor.Descriptor): - field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE - else: - field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM - - field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType( - field_proto.type) - - if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE - or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP): - field_desc.message_type = desc - - if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: - field_desc.enum_type = desc - - if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED: - field_desc.has_default = False - field_desc.default_value = [] - elif field_proto.HasField('default_value'): - field_desc.has_default = True - if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or - field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): - field_desc.default_value = float(field_proto.default_value) - elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: - field_desc.default_value = field_proto.default_value - elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: - field_desc.default_value = field_proto.default_value.lower() == 'true' - elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: - field_desc.default_value = field_desc.enum_type.values_by_name[ - field_proto.default_value].index - else: - field_desc.default_value = int(field_proto.default_value) - else: - field_desc.has_default = False - field_desc.default_value = None + main_desc = self._GetTypeFromScope(package, desc_proto.name, scope) - field_desc.type = field_proto.type + if package == '.': + nested_package = _PrefixWithDot(desc_proto.name) + else: + nested_package = '.'.join([package, desc_proto.name]) + + for field_proto, field_desc in zip(desc_proto.field, main_desc.fields): + self.SetFieldType(field_proto, 
field_desc, nested_package, scope) + + for extension_proto, extension_desc in ( + zip(desc_proto.extension, main_desc.extensions)): + extension_desc.containing_type = self._GetTypeFromScope( + nested_package, extension_proto.extendee, scope) + self.SetFieldType(extension_proto, extension_desc, nested_package, scope) for nested_type in desc_proto.nested_type: - self._SetFieldTypes(nested_type, scope) + self.SetAllFieldTypes(nested_package, nested_type, scope) + + def SetFieldType(self, field_proto, field_desc, package, scope): + """Sets the field's type, cpp_type, message_type and enum_type. + + Args: + field_proto: Data about the field in proto format. + field_desc: The descriptor to modify. + package: The package the field's container is in. + scope: Enclosing scope of available types. + """ + if field_proto.type_name: + desc = self._GetTypeFromScope(package, field_proto.type_name, scope) + else: + desc = None + + if not field_proto.HasField('type'): + if isinstance(desc, descriptor.Descriptor): + field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE + else: + field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM + + field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType( + field_proto.type) + + if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE + or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP): + field_desc.message_type = desc + + if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: + field_desc.enum_type = desc + + if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED: + field_desc.has_default_value = False + field_desc.default_value = [] + elif field_proto.HasField('default_value'): + field_desc.has_default_value = True + if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or + field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): + field_desc.default_value = float(field_proto.default_value) + elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: + 
field_desc.default_value = field_proto.default_value + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: + field_desc.default_value = field_proto.default_value.lower() == 'true' + elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: + field_desc.default_value = field_desc.enum_type.values_by_name[ + field_proto.default_value].index + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: + field_desc.default_value = text_encoding.CUnescape( + field_proto.default_value) + else: + field_desc.default_value = int(field_proto.default_value) + else: + field_desc.has_default_value = False + field_desc.default_value = None + + field_desc.type = field_proto.type def _MakeEnumValueDescriptor(self, value_proto, index): """Creates a enum value descriptor object from a enum value proto. @@ -452,76 +584,60 @@ class DescriptorPool(object): options=value_proto.options, type=None) - def _ExtractSymbols(self, desc_protos, package): + def _ExtractSymbols(self, descriptors): """Pulls out all the symbols from descriptor protos. Args: - desc_protos: The protos to extract symbols from. - package: The package containing the descriptor type. + descriptors: The messages to extract descriptors from. Yields: A two element tuple of the type name and descriptor object. 
""" - for desc_proto in desc_protos: - if package: - message_name = '.'.join((package, desc_proto.name)) - else: - message_name = desc_proto.name - message_desc = self.FindMessageTypeByName(message_name) - yield (message_name, message_desc) - for symbol in self._ExtractSymbols(desc_proto.nested_type, message_name): - yield symbol - for symbol in self._ExtractEnums(desc_proto.enum_type, message_name): + for desc in descriptors: + yield (_PrefixWithDot(desc.full_name), desc) + for symbol in self._ExtractSymbols(desc.nested_types): yield symbol + for enum in desc.enum_types: + yield (_PrefixWithDot(enum.full_name), enum) - def _ExtractEnums(self, enum_protos, package): - """Pulls out all the symbols from enum protos. + def _GetDeps(self, dependencies): + """Recursively finds dependencies for file protos. Args: - enum_protos: The protos to extract symbols from. - package: The package containing the enum type. + dependencies: The names of the files being depended on. Yields: - A two element tuple of the type name and enum descriptor object. + Each direct and indirect dependency. """ - for enum_proto in enum_protos: - if package: - enum_name = '.'.join((package, enum_proto.name)) - else: - enum_name = enum_proto.name - enum_desc = self.FindEnumTypeByName(enum_name) - yield (enum_name, enum_desc) + for dependency in dependencies: + dep_desc = self.FindFileByName(dependency) + yield dep_desc + for parent_dep in dep_desc.dependencies: + yield parent_dep - def _ExtractMessages(self, desc_protos): - """Pulls out all the message protos from descriptos. + def _GetTypeFromScope(self, package, type_name, scope): + """Finds a given type name in the current scope. Args: - desc_protos: The protos to extract symbols from. + package: The package the proto should be located in. + type_name: The name of the type to be found in the scope. + scope: Dict mapping short and full symbols to message and enum types. - Yields: - Descriptor protos. 
+ Returns: + The descriptor for the requested type. """ + if type_name not in scope: + components = _PrefixWithDot(package).split('.') + while components: + possible_match = '.'.join(components + [type_name]) + if possible_match in scope: + type_name = possible_match + break + else: + components.pop(-1) + return scope[type_name] - for desc_proto in desc_protos: - yield desc_proto - for message in self._ExtractMessages(desc_proto.nested_type): - yield message - - def _GetDeps(self, file_proto): - """Recursively finds dependencies for file protos. - - Args: - file_proto: The proto to get dependencies from. - - Yields: - Each direct and indirect dependency. - """ - for dependency in file_proto.dependency: - dep_desc = self.FindFileByName(dependency) - dep_proto = descriptor_pb2.FileDescriptorProto.FromString( - dep_desc.serialized_pb) - yield dep_proto - for parent_dep in self._GetDeps(dep_proto): - yield parent_dep +def _PrefixWithDot(name): + return name if name.startswith('.') else '.%s' % name -- cgit v1.2.3