aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/descriptor_pool.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/descriptor_pool.py')
-rw-r--r--python/google/protobuf/descriptor_pool.py99
1 files changed, 99 insertions, 0 deletions
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index 5f43ee5f..28b7e843 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -57,6 +57,8 @@ directly instead of this class.
__author__ = 'matthewtoia@google.com (Matt Toia)'
+import collections
+
from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import text_encoding
@@ -88,6 +90,14 @@ def _OptionsOrNone(descriptor_proto):
return None
+def _IsMessageSetExtension(field):
+ return (field.is_extension and
+ field.containing_type.has_options and
+ field.containing_type.GetOptions().message_set_wire_format and
+ field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
+ field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL)
+
+
class DescriptorPool(object):
"""A collection of protobufs dynamically constructed by descriptor protos."""
@@ -115,6 +125,12 @@ class DescriptorPool(object):
self._descriptors = {}
self._enum_descriptors = {}
self._file_descriptors = {}
+ self._toplevel_extensions = {}
+ # We store extensions in two two-level mappings: The first key is the
+ # descriptor of the message being extended, the second key is the extension
+ # full name or its tag number.
+ self._extensions_by_name = collections.defaultdict(dict)
+ self._extensions_by_number = collections.defaultdict(dict)
def Add(self, file_desc_proto):
"""Adds the FileDescriptorProto and its types to this pool.
@@ -170,6 +186,48 @@ class DescriptorPool(object):
self._enum_descriptors[enum_desc.full_name] = enum_desc
self.AddFileDescriptor(enum_desc.file)
+ def AddExtensionDescriptor(self, extension):
+ """Adds a FieldDescriptor describing an extension to the pool.
+
+ Args:
+ extension: A FieldDescriptor.
+
+ Raises:
+ AssertionError: when another extension with the same number extends the
+ same message.
+ TypeError: when the specified extension is not a
+ descriptor.FieldDescriptor.
+ """
+ if not (isinstance(extension, descriptor.FieldDescriptor) and
+ extension.is_extension):
+ raise TypeError('Expected an extension descriptor.')
+
+ if extension.extension_scope is None:
+ self._toplevel_extensions[extension.full_name] = extension
+
+ try:
+ existing_desc = self._extensions_by_number[
+ extension.containing_type][extension.number]
+ except KeyError:
+ pass
+ else:
+ if extension is not existing_desc:
+ raise AssertionError(
+ 'Extensions "%s" and "%s" both try to extend message type "%s" '
+ 'with field number %d.' %
+ (extension.full_name, existing_desc.full_name,
+ extension.containing_type.full_name, extension.number))
+
+ self._extensions_by_number[extension.containing_type][
+ extension.number] = extension
+ self._extensions_by_name[extension.containing_type][
+ extension.full_name] = extension
+
+ # Also register MessageSet extensions with the type name.
+ if _IsMessageSetExtension(extension):
+ self._extensions_by_name[extension.containing_type][
+ extension.message_type.full_name] = extension
+
def AddFileDescriptor(self, file_desc):
"""Adds a FileDescriptor to the pool, non-recursively.
@@ -302,6 +360,14 @@ class DescriptorPool(object):
A FieldDescriptor, describing the named extension.
"""
full_name = _NormalizeFullyQualifiedName(full_name)
+ try:
+ # The proto compiler does not give any link between the FileDescriptor
+ # and top-level extensions unless the FileDescriptorProto is added to
+ # the DescriptorDatabase, but this can impact memory usage.
+ # So we registered these extensions by name explicitly.
+ return self._toplevel_extensions[full_name]
+ except KeyError:
+ pass
message_name, _, extension_name = full_name.rpartition('.')
try:
# Most extensions are nested inside a message.
@@ -311,6 +377,39 @@ class DescriptorPool(object):
scope = self.FindFileContainingSymbol(full_name)
return scope.extensions_by_name[extension_name]
+ def FindExtensionByNumber(self, message_descriptor, number):
+ """Gets the extension of the specified message with the specified number.
+
+ Extensions have to be registered to this pool by calling
+ AddExtensionDescriptor.
+
+ Args:
+ message_descriptor: descriptor of the extended message.
+ number: integer, number of the extension field.
+
+ Returns:
+ A FieldDescriptor describing the extension.
+
+ Raise:
+ KeyError: when no extension with the given number is known for the
+ specified message.
+ """
+ return self._extensions_by_number[message_descriptor][number]
+
+ def FindAllExtensions(self, message_descriptor):
+ """Gets all the known extension of a given message.
+
+ Extensions have to be registered to this pool by calling
+ AddExtensionDescriptor.
+
+ Args:
+ message_descriptor: descriptor of the extended message.
+
+ Returns:
+ A list of FieldDescriptor describing the extensions.
+ """
+ return self._extensions_by_number[message_descriptor].values()
+
def _ConvertFileProtoToFileDescriptor(self, file_proto):
"""Creates a FileDescriptor from a proto or returns a cached copy.