aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/symbol_database.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/symbol_database.py')
-rw-r--r--python/google/protobuf/symbol_database.py82
1 files changed, 33 insertions, 49 deletions
diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py
index 87760f26..aa466abd 100644
--- a/python/google/protobuf/symbol_database.py
+++ b/python/google/protobuf/symbol_database.py
@@ -30,11 +30,9 @@
"""A database of Python protocol buffer generated symbols.
-SymbolDatabase makes it easy to create new instances of a registered type, given
-only the type's protocol buffer symbol name. Once all symbols are registered,
-they can be accessed using either the MessageFactory interface which
-SymbolDatabase exposes, or the DescriptorPool interface of the underlying
-pool.
+SymbolDatabase is the MessageFactory for messages generated at compile time,
+and makes it easy to create new instances of a registered type, given only the
+type's protocol buffer symbol name.
Example usage:
@@ -61,27 +59,17 @@ Example usage:
from google.protobuf import descriptor_pool
+from google.protobuf import message_factory
-class SymbolDatabase(object):
- """A database of Python generated symbols.
-
- SymbolDatabase also models message_factory.MessageFactory.
-
- The symbol database can be used to keep a global registry of all protocol
- buffer types used within a program.
- """
-
- def __init__(self, pool=None):
- """Constructor."""
-
- self._symbols = {}
- self._symbols_by_file = {}
- self.pool = pool or descriptor_pool.Default()
+class SymbolDatabase(message_factory.MessageFactory):
+ """A database of Python generated symbols."""
def RegisterMessage(self, message):
"""Registers the given message type in the local database.
+ Calls to GetSymbol() and GetMessages() will return messages registered here.
+
Args:
message: a message.Message, to be registered.
@@ -90,10 +78,7 @@ class SymbolDatabase(object):
"""
desc = message.DESCRIPTOR
- self._symbols[desc.full_name] = message
- if desc.file.name not in self._symbols_by_file:
- self._symbols_by_file[desc.file.name] = {}
- self._symbols_by_file[desc.file.name][desc.full_name] = message
+ self._classes[desc.full_name] = message
self.pool.AddDescriptor(desc)
return message
@@ -136,47 +121,46 @@ class SymbolDatabase(object):
KeyError: if the symbol could not be found.
"""
- return self._symbols[symbol]
-
- def GetPrototype(self, descriptor):
- """Builds a proto2 message class based on the passed in descriptor.
-
- Passing a descriptor with a fully qualified name matching a previous
- invocation will cause the same class to be returned.
-
- Args:
- descriptor: The descriptor to build from.
-
- Returns:
- A class describing the passed in descriptor.
- """
-
- return self.GetSymbol(descriptor.full_name)
+ return self._classes[symbol]
def GetMessages(self, files):
- """Gets all the messages from a specified file.
-
- This will find and resolve dependencies, failing if they are not registered
- in the symbol database.
+ # TODO(amauryfa): Fix the differences with MessageFactory.
+ """Gets all registered messages from a specified file.
+ Only messages already created and registered will be returned; (this is the
+ case for imported _pb2 modules)
+ But unlike MessageFactory, this version also returns nested messages.
Args:
files: The file names to extract messages from.
Returns:
- A dictionary mapping proto names to the message classes. This will include
- any dependent messages as well as any messages defined in the same file as
- a specified message.
+ A dictionary mapping proto names to the message classes.
Raises:
KeyError: if a file could not be found.
"""
+ def _GetAllMessageNames(desc):
+ """Walk a message Descriptor and recursively yields all message names."""
+ yield desc.full_name
+ for msg_desc in desc.nested_types:
+ for full_name in _GetAllMessageNames(msg_desc):
+ yield full_name
+
result = {}
- for f in files:
- result.update(self._symbols_by_file[f])
+ for file_name in files:
+ file_desc = self.pool.FindFileByName(file_name)
+ for msg_desc in file_desc.message_types_by_name.values():
+ for full_name in _GetAllMessageNames(msg_desc):
+ try:
+ result[full_name] = self._classes[full_name]
+ except KeyError:
+ # This descriptor has no registered class, skip it.
+ pass
return result
+
_DEFAULT = SymbolDatabase(pool=descriptor_pool.Default())