aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/descriptor_pool.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/descriptor_pool.py')
-rw-r--r--python/google/protobuf/descriptor_pool.py99
1 files changed, 84 insertions, 15 deletions
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index fc3a7f44..7844575f 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -124,6 +124,7 @@ class DescriptorPool(object):
self._descriptor_db = descriptor_db
self._descriptors = {}
self._enum_descriptors = {}
+ self._service_descriptors = {}
self._file_descriptors = {}
self._toplevel_extensions = {}
# We store extensions in two two-level mappings: The first key is the
@@ -174,7 +175,7 @@ class DescriptorPool(object):
def AddEnumDescriptor(self, enum_desc):
"""Adds an EnumDescriptor to the pool.
- This method also registers the FileDescriptor associated with the message.
+ This method also registers the FileDescriptor associated with the enum.
Args:
enum_desc: An EnumDescriptor.
@@ -186,6 +187,18 @@ class DescriptorPool(object):
self._enum_descriptors[enum_desc.full_name] = enum_desc
self.AddFileDescriptor(enum_desc.file)
+ def AddServiceDescriptor(self, service_desc):
+ """Adds a ServiceDescriptor to the pool.
+
+ Args:
+ service_desc: A ServiceDescriptor.
+ """
+
+ if not isinstance(service_desc, descriptor.ServiceDescriptor):
+ raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
+
+ self._service_descriptors[service_desc.full_name] = service_desc
+
def AddExtensionDescriptor(self, extension):
"""Adds a FieldDescriptor describing an extension to the pool.
@@ -252,7 +265,7 @@ class DescriptorPool(object):
A FileDescriptor for the named file.
Raises:
- KeyError: if the file can not be found in the pool.
+ KeyError: if the file cannot be found in the pool.
"""
try:
@@ -281,7 +294,7 @@ class DescriptorPool(object):
A FileDescriptor that contains the specified symbol.
Raises:
- KeyError: if the file can not be found in the pool.
+ KeyError: if the file cannot be found in the pool.
"""
symbol = _NormalizeFullyQualifiedName(symbol)
@@ -296,15 +309,18 @@ class DescriptorPool(object):
pass
try:
- file_proto = self._internal_db.FindFileContainingSymbol(symbol)
- except KeyError as error:
- if self._descriptor_db:
- file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
- else:
- raise error
- if not file_proto:
+ return self._FindFileContainingSymbolInDb(symbol)
+ except KeyError:
+ pass
+
+ # Try nested extensions inside a message.
+ message_name, _, extension_name = symbol.rpartition('.')
+ try:
+ scope = self.FindMessageTypeByName(message_name)
+ assert scope.extensions_by_name[extension_name]
+ return scope.file
+ except KeyError:
raise KeyError('Cannot find a file containing %s' % symbol)
- return self._ConvertFileProtoToFileDescriptor(file_proto)
def FindMessageTypeByName(self, full_name):
"""Loads the named descriptor from the pool.
@@ -314,11 +330,14 @@ class DescriptorPool(object):
Returns:
The descriptor for the named type.
+
+ Raises:
+ KeyError: if the message cannot be found in the pool.
"""
full_name = _NormalizeFullyQualifiedName(full_name)
if full_name not in self._descriptors:
- self.FindFileContainingSymbol(full_name)
+ self._FindFileContainingSymbolInDb(full_name)
return self._descriptors[full_name]
def FindEnumTypeByName(self, full_name):
@@ -329,11 +348,14 @@ class DescriptorPool(object):
Returns:
The enum descriptor for the named type.
+
+ Raises:
+ KeyError: if the enum cannot be found in the pool.
"""
full_name = _NormalizeFullyQualifiedName(full_name)
if full_name not in self._enum_descriptors:
- self.FindFileContainingSymbol(full_name)
+ self._FindFileContainingSymbolInDb(full_name)
return self._enum_descriptors[full_name]
def FindFieldByName(self, full_name):
@@ -344,6 +366,9 @@ class DescriptorPool(object):
Returns:
The field descriptor for the named field.
+
+ Raises:
+ KeyError: if the field cannot be found in the pool.
"""
full_name = _NormalizeFullyQualifiedName(full_name)
message_name, _, field_name = full_name.rpartition('.')
@@ -358,6 +383,9 @@ class DescriptorPool(object):
Returns:
A FieldDescriptor, describing the named extension.
+
+ Raises:
+ KeyError: if the extension cannot be found in the pool.
"""
full_name = _NormalizeFullyQualifiedName(full_name)
try:
@@ -374,7 +402,7 @@ class DescriptorPool(object):
scope = self.FindMessageTypeByName(message_name)
except KeyError:
# Some extensions are defined at file scope.
- scope = self.FindFileContainingSymbol(full_name)
+ scope = self._FindFileContainingSymbolInDb(full_name)
return scope.extensions_by_name[extension_name]
def FindExtensionByNumber(self, message_descriptor, number):
@@ -390,7 +418,7 @@ class DescriptorPool(object):
Returns:
A FieldDescriptor describing the extension.
- Raise:
+ Raises:
KeyError: when no extension with the given number is known for the
specified message.
"""
@@ -410,6 +438,46 @@ class DescriptorPool(object):
"""
return list(self._extensions_by_number[message_descriptor].values())
+ def FindServiceByName(self, full_name):
+ """Loads the named service descriptor from the pool.
+
+ Args:
+ full_name: The full name of the service descriptor to load.
+
+ Returns:
+ The service descriptor for the named service.
+
+ Raises:
+ KeyError: if the service cannot be found in the pool.
+ """
+ full_name = _NormalizeFullyQualifiedName(full_name)
+ if full_name not in self._service_descriptors:
+ self._FindFileContainingSymbolInDb(full_name)
+ return self._service_descriptors[full_name]
+
+ def _FindFileContainingSymbolInDb(self, symbol):
+ """Finds the file in descriptor DB containing the specified symbol.
+
+ Args:
+ symbol: The name of the symbol to search for.
+
+ Returns:
+ A FileDescriptor that contains the specified symbol.
+
+ Raises:
+ KeyError: if the file cannot be found in the descriptor database.
+ """
+ try:
+ file_proto = self._internal_db.FindFileContainingSymbol(symbol)
+ except KeyError as error:
+ if self._descriptor_db:
+ file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
+ else:
+ raise error
+ if not file_proto:
+ raise KeyError('Cannot find a file containing %s' % symbol)
+ return self._ConvertFileProtoToFileDescriptor(file_proto)
+
def _ConvertFileProtoToFileDescriptor(self, file_proto):
"""Creates a FileDescriptor from a proto or returns a cached copy.
@@ -804,6 +872,7 @@ class DescriptorPool(object):
methods=methods,
options=_OptionsOrNone(service_proto),
file=file_desc)
+ self._service_descriptors[service_name] = desc
return desc
def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,