From 899460c9cb328e51b5da0ffe5d73e03c8f00dd15 Mon Sep 17 00:00:00 2001 From: Jie Luo Date: Mon, 10 Apr 2017 16:37:57 -0700 Subject: cherrypick descriptor_pool.FindFileContainingSymbol by extensions (#2962) * Use PyUnicode_AsEncodedString() instead of PyUnicode_AsEncodedObject() * Cherrypick the fix descriptor_pool.FindFileContainingSymbol by extensions. --- python/google/protobuf/descriptor_pool.py | 36 +++++++++++++++++++--- .../protobuf/internal/descriptor_pool_test.py | 9 ++++++ python/google/protobuf/pyext/message.cc | 2 +- .../protobuf/compiler/python/python_generator.cc | 6 ++-- 4 files changed, 45 insertions(+), 8 deletions(-) diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 7844575f..7bd2506b 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -127,6 +127,9 @@ class DescriptorPool(object): self._service_descriptors = {} self._file_descriptors = {} self._toplevel_extensions = {} + # TODO(jieluo): Remove _file_desc_by_toplevel_extension when + # FieldDescriptor.file is added in code gen. + self._file_desc_by_toplevel_extension = {} # We store extensions in two two-level mappings: The first key is the # descriptor of the message being extended, the second key is the extension # full name or its tag number. @@ -170,7 +173,7 @@ class DescriptorPool(object): raise TypeError('Expected instance of descriptor.Descriptor.') self._descriptors[desc.full_name] = desc - self.AddFileDescriptor(desc.file) + self._AddFileDescriptor(desc.file) def AddEnumDescriptor(self, enum_desc): """Adds an EnumDescriptor to the pool. @@ -185,7 +188,7 @@ class DescriptorPool(object): raise TypeError('Expected instance of descriptor.EnumDescriptor.') self._enum_descriptors[enum_desc.full_name] = enum_desc - self.AddFileDescriptor(enum_desc.file) + self._AddFileDescriptor(enum_desc.file) def AddServiceDescriptor(self, service_desc): """Adds a ServiceDescriptor to the pool. @@ -251,6 +254,23 @@ class DescriptorPool(object): file_desc: A FileDescriptor. """ + self._AddFileDescriptor(file_desc) + # TODO(jieluo): This is a temporary solution for FieldDescriptor.file. + # Remove it when FieldDescriptor.file is added in code gen. + for extension in file_desc.extensions_by_name.itervalues(): + self._file_desc_by_toplevel_extension[ + extension.full_name] = file_desc + + def _AddFileDescriptor(self, file_desc): + """Adds a FileDescriptor to the pool, non-recursively. + + If the FileDescriptor contains messages or enums, the caller must explicitly + register them. + + Args: + file_desc: A FileDescriptor. + """ + if not isinstance(file_desc, descriptor.FileDescriptor): raise TypeError('Expected instance of descriptor.FileDescriptor.') self._file_descriptors[file_desc.name] = file_desc @@ -313,12 +333,18 @@ class DescriptorPool(object): except KeyError: pass + try: + return self._file_desc_by_toplevel_extension[symbol] + except KeyError: + pass + # Try nested extensions inside a message. message_name, _, extension_name = symbol.rpartition('.') try: - scope = self.FindMessageTypeByName(message_name) - assert scope.extensions_by_name[extension_name] - return scope.file + message = self.FindMessageTypeByName(message_name) + assert message.extensions_by_name[extension_name] + return message.file + except KeyError: raise KeyError('Cannot find a file containing %s' % symbol) diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index 2ba1d285..c1733a48 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -63,6 +63,9 @@ from google.protobuf import symbol_database class DescriptorPoolTest(unittest.TestCase): def setUp(self): + # TODO(jieluo): Should make the pool which is created by + # serialized_pb same with generated pool. + # TODO(jieluo): More test coverage for the generated pool. self.pool = descriptor_pool.DescriptorPool() self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( factory_test1_pb2.DESCRIPTOR.serialized_pb) @@ -128,6 +131,12 @@ class DescriptorPoolTest(unittest.TestCase): self.assertEqual('google/protobuf/internal/factory_test2.proto', file_desc4.name) + # Tests the generated pool. + assert descriptor_pool.Default().FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message.one_more_field') + assert descriptor_pool.Default().FindFileContainingSymbol( + 'google.protobuf.python.internal.another_field') + def testFindFileContainingSymbolFailure(self): with self.assertRaises(KeyError): self.pool.FindFileContainingSymbol('Does not exist') diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index c810b788..85aaa46f 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -779,7 +779,7 @@ PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) { encoded_string = arg; // Already encoded. Py_INCREF(encoded_string); } else { - encoded_string = PyUnicode_AsEncodedObject(arg, "utf-8", NULL); + encoded_string = PyUnicode_AsEncodedString(arg, "utf-8", NULL); } } else { // In this case field type is "bytes". diff --git a/src/google/protobuf/compiler/python/python_generator.cc b/src/google/protobuf/compiler/python/python_generator.cc index f83f155a..21a7e158 100644 --- a/src/google/protobuf/compiler/python/python_generator.cc +++ b/src/google/protobuf/compiler/python/python_generator.cc @@ -445,8 +445,6 @@ void Generator::PrintFileDescriptor() const { printer_->Outdent(); printer_->Print(")\n"); - printer_->Print("_sym_db.RegisterFileDescriptor($name$)\n", "name", - kDescriptorKey); printer_->Print("\n"); } @@ -999,6 +997,10 @@ void Generator::FixForeignFieldsInDescriptors() const { for (int i = 0; i < file_->extension_count(); ++i) { AddExtensionToFileDescriptor(*file_->extension(i)); } + // TODO(jieluo): Move this register to PrintFileDescriptor() when + // FieldDescriptor.file is added in generated file. + printer_->Print("_sym_db.RegisterFileDescriptor($name$)\n", "name", + kDescriptorKey); printer_->Print("\n"); } -- cgit v1.2.3