aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jie Luo <anandolee@gmail.com>2017-04-10 16:37:57 -0700
committerGravatar Feng Xiao <xfxyjwf@gmail.com>2017-04-10 16:37:57 -0700
commit899460c9cb328e51b5da0ffe5d73e03c8f00dd15 (patch)
tree882fc6e219956b77f3f30c8560b10c017fa7f61f
parente91caa1f197ec1bd48a713c435c47e6f3948a1fb (diff)
cherrypick descriptor_pool.FindFileContainingSymbol by extensions (#2962)
* Use PyUnicode_AsEncodedString() instead of PyUnicode_AsEncodedObject() * Cherrypick the fix descriptor_pool.FindFileContainingSymbol by extensions.
-rw-r--r--python/google/protobuf/descriptor_pool.py36
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test.py9
-rw-r--r--python/google/protobuf/pyext/message.cc2
-rw-r--r--src/google/protobuf/compiler/python/python_generator.cc6
4 files changed, 45 insertions, 8 deletions
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index 7844575f..7bd2506b 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -127,6 +127,9 @@ class DescriptorPool(object):
self._service_descriptors = {}
self._file_descriptors = {}
self._toplevel_extensions = {}
+ # TODO(jieluo): Remove _file_desc_by_toplevel_extension when
+ # FieldDescriptor.file is added in code gen.
+ self._file_desc_by_toplevel_extension = {}
# We store extensions in two two-level mappings: The first key is the
# descriptor of the message being extended, the second key is the extension
# full name or its tag number.
@@ -170,7 +173,7 @@ class DescriptorPool(object):
raise TypeError('Expected instance of descriptor.Descriptor.')
self._descriptors[desc.full_name] = desc
- self.AddFileDescriptor(desc.file)
+ self._AddFileDescriptor(desc.file)
def AddEnumDescriptor(self, enum_desc):
"""Adds an EnumDescriptor to the pool.
@@ -185,7 +188,7 @@ class DescriptorPool(object):
raise TypeError('Expected instance of descriptor.EnumDescriptor.')
self._enum_descriptors[enum_desc.full_name] = enum_desc
- self.AddFileDescriptor(enum_desc.file)
+ self._AddFileDescriptor(enum_desc.file)
def AddServiceDescriptor(self, service_desc):
"""Adds a ServiceDescriptor to the pool.
@@ -251,6 +254,23 @@ class DescriptorPool(object):
file_desc: A FileDescriptor.
"""
+ self._AddFileDescriptor(file_desc)
+ # TODO(jieluo): This is a temporary solution for FieldDescriptor.file.
+ # Remove it when FieldDescriptor.file is added in code gen.
+ for extension in file_desc.extensions_by_name.itervalues():
+ self._file_desc_by_toplevel_extension[
+ extension.full_name] = file_desc
+
+ def _AddFileDescriptor(self, file_desc):
+ """Adds a FileDescriptor to the pool, non-recursively.
+
+ If the FileDescriptor contains messages or enums, the caller must explicitly
+ register them.
+
+ Args:
+ file_desc: A FileDescriptor.
+ """
+
if not isinstance(file_desc, descriptor.FileDescriptor):
raise TypeError('Expected instance of descriptor.FileDescriptor.')
self._file_descriptors[file_desc.name] = file_desc
@@ -313,12 +333,18 @@ class DescriptorPool(object):
except KeyError:
pass
+ try:
+ return self._file_desc_by_toplevel_extension[symbol]
+ except KeyError:
+ pass
+
# Try nested extensions inside a message.
message_name, _, extension_name = symbol.rpartition('.')
try:
- scope = self.FindMessageTypeByName(message_name)
- assert scope.extensions_by_name[extension_name]
- return scope.file
+ message = self.FindMessageTypeByName(message_name)
+ assert message.extensions_by_name[extension_name]
+ return message.file
+
except KeyError:
raise KeyError('Cannot find a file containing %s' % symbol)
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index 2ba1d285..c1733a48 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -63,6 +63,9 @@ from google.protobuf import symbol_database
class DescriptorPoolTest(unittest.TestCase):
def setUp(self):
+ # TODO(jieluo): Should make the pool which is created by
+ # serialized_pb same with generated pool.
+ # TODO(jieluo): More test coverage for the generated pool.
self.pool = descriptor_pool.DescriptorPool()
self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
factory_test1_pb2.DESCRIPTOR.serialized_pb)
@@ -128,6 +131,12 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertEqual('google/protobuf/internal/factory_test2.proto',
file_desc4.name)
+ # Tests the generated pool.
+ assert descriptor_pool.Default().FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory2Message.one_more_field')
+ assert descriptor_pool.Default().FindFileContainingSymbol(
+ 'google.protobuf.python.internal.another_field')
+
def testFindFileContainingSymbolFailure(self):
with self.assertRaises(KeyError):
self.pool.FindFileContainingSymbol('Does not exist')
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index c810b788..85aaa46f 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -779,7 +779,7 @@ PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) {
encoded_string = arg; // Already encoded.
Py_INCREF(encoded_string);
} else {
- encoded_string = PyUnicode_AsEncodedObject(arg, "utf-8", NULL);
+ encoded_string = PyUnicode_AsEncodedString(arg, "utf-8", NULL);
}
} else {
// In this case field type is "bytes".
diff --git a/src/google/protobuf/compiler/python/python_generator.cc b/src/google/protobuf/compiler/python/python_generator.cc
index f83f155a..21a7e158 100644
--- a/src/google/protobuf/compiler/python/python_generator.cc
+++ b/src/google/protobuf/compiler/python/python_generator.cc
@@ -445,8 +445,6 @@ void Generator::PrintFileDescriptor() const {
printer_->Outdent();
printer_->Print(")\n");
- printer_->Print("_sym_db.RegisterFileDescriptor($name$)\n", "name",
- kDescriptorKey);
printer_->Print("\n");
}
@@ -999,6 +997,10 @@ void Generator::FixForeignFieldsInDescriptors() const {
for (int i = 0; i < file_->extension_count(); ++i) {
AddExtensionToFileDescriptor(*file_->extension(i));
}
+ // TODO(jieluo): Move this register to PrintFileDescriptor() when
+ // FieldDescriptor.file is added in generated file.
+ printer_->Print("_sym_db.RegisterFileDescriptor($name$)\n", "name",
+ kDescriptorKey);
printer_->Print("\n");
}