aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/descriptor_pool.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/descriptor_pool.py')
-rw-r--r--python/google/protobuf/descriptor_pool.py43
1 files changed, 42 insertions, 1 deletions
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index cb7146b6..8983f76f 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -58,6 +58,7 @@ directly instead of this class.
__author__ = 'matthewtoia@google.com (Matt Toia)'
import collections
+import warnings
from google.protobuf import descriptor
from google.protobuf import descriptor_database
@@ -127,12 +128,38 @@ class DescriptorPool(object):
self._service_descriptors = {}
self._file_descriptors = {}
self._toplevel_extensions = {}
+ # TODO(jieluo): Remove _file_desc_by_toplevel_extension after
+ # maybe year 2020 for compatibility issue (with 3.4.1 only).
+ self._file_desc_by_toplevel_extension = {}
# We store extensions in two two-level mappings: The first key is the
# descriptor of the message being extended, the second key is the extension
# full name or its tag number.
self._extensions_by_name = collections.defaultdict(dict)
self._extensions_by_number = collections.defaultdict(dict)
+ def _CheckConflictRegister(self, desc):
+ """Check if the descriptor name conflicts with another of the same name.
+
+ Args:
+ desc: Descriptor of a message, enum, service or extension.
+ """
+ desc_name = desc.full_name
+ for register, descriptor_type in [
+ (self._descriptors, descriptor.Descriptor),
+ (self._enum_descriptors, descriptor.EnumDescriptor),
+ (self._service_descriptors, descriptor.ServiceDescriptor),
+ (self._toplevel_extensions, descriptor.FieldDescriptor)]:
+ if desc_name in register:
+ file_name = register[desc_name].file.name
+ if not isinstance(desc, descriptor_type) or (
+ file_name != desc.file.name):
+ warn_msg = ('Conflict register for file "' + desc.file.name +
+ '": ' + desc_name +
+ ' is already defined in file "' +
+ file_name + '"')
+ warnings.warn(warn_msg, RuntimeWarning)
+ return
+
def Add(self, file_desc_proto):
"""Adds the FileDescriptorProto and its types to this pool.
@@ -169,6 +196,8 @@ class DescriptorPool(object):
if not isinstance(desc, descriptor.Descriptor):
raise TypeError('Expected instance of descriptor.Descriptor.')
+ self._CheckConflictRegister(desc)
+
self._descriptors[desc.full_name] = desc
self._AddFileDescriptor(desc.file)
@@ -184,6 +213,7 @@ class DescriptorPool(object):
if not isinstance(enum_desc, descriptor.EnumDescriptor):
raise TypeError('Expected instance of descriptor.EnumDescriptor.')
+ self._CheckConflictRegister(enum_desc)
self._enum_descriptors[enum_desc.full_name] = enum_desc
self._AddFileDescriptor(enum_desc.file)
@@ -197,6 +227,7 @@ class DescriptorPool(object):
if not isinstance(service_desc, descriptor.ServiceDescriptor):
raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
+ self._CheckConflictRegister(service_desc)
self._service_descriptors[service_desc.full_name] = service_desc
def AddExtensionDescriptor(self, extension):
@@ -216,6 +247,7 @@ class DescriptorPool(object):
raise TypeError('Expected an extension descriptor.')
if extension.extension_scope is None:
+ self._CheckConflictRegister(extension)
self._toplevel_extensions[extension.full_name] = extension
try:
@@ -252,6 +284,12 @@ class DescriptorPool(object):
"""
self._AddFileDescriptor(file_desc)
+ # TODO(jieluo): This is a temporary solution for FieldDescriptor.file.
+ # FieldDescriptor.file is added in code gen. Remove this solution after
+ # maybe 2020 for compatibility reason (with 3.4.1 only).
+ for extension in file_desc.extensions_by_name.values():
+ self._file_desc_by_toplevel_extension[
+ extension.full_name] = file_desc
def _AddFileDescriptor(self, file_desc):
"""Adds a FileDescriptor to the pool, non-recursively.
@@ -331,7 +369,7 @@ class DescriptorPool(object):
pass
try:
- return self._toplevel_extensions[symbol].file
+ return self._file_desc_by_toplevel_extension[symbol]
except KeyError:
pass
@@ -680,6 +718,7 @@ class DescriptorPool(object):
fields[field_index].containing_oneof = oneofs[oneof_index]
scope[_PrefixWithDot(desc_name)] = desc
+ self._CheckConflictRegister(desc)
self._descriptors[desc_name] = desc
return desc
@@ -718,6 +757,7 @@ class DescriptorPool(object):
containing_type=containing_type,
options=_OptionsOrNone(enum_proto))
scope['.%s' % enum_name] = desc
+ self._CheckConflictRegister(desc)
self._enum_descriptors[enum_name] = desc
return desc
@@ -914,6 +954,7 @@ class DescriptorPool(object):
methods=methods,
options=_OptionsOrNone(service_proto),
file=file_desc)
+ self._CheckConflictRegister(desc)
self._service_descriptors[service_name] = desc
return desc