diff options
Diffstat (limited to 'python/google/protobuf/internal/descriptor_pool_test.py')
-rw-r--r-- | python/google/protobuf/internal/descriptor_pool_test.py | 478 |
1 files changed, 385 insertions, 93 deletions
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index f1d6bf99..2cbf7813 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -34,12 +34,16 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' +import copy import os +import sys +import warnings try: - import unittest2 as unittest + import unittest2 as unittest #PY26 except ImportError: import unittest + from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_import_public_pb2 from google.protobuf import unittest_pb2 @@ -49,7 +53,8 @@ from google.protobuf.internal import descriptor_pool_test1_pb2 from google.protobuf.internal import descriptor_pool_test2_pb2 from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 -from google.protobuf.internal import test_util +from google.protobuf.internal import file_options_test_pb2 +from google.protobuf.internal import more_messages_pb2 from google.protobuf import descriptor from google.protobuf import descriptor_database from google.protobuf import descriptor_pool @@ -57,19 +62,8 @@ from google.protobuf import message_factory from google.protobuf import symbol_database -class DescriptorPoolTest(unittest.TestCase): - - def CreatePool(self): - return descriptor_pool.DescriptorPool() - def setUp(self): - self.pool = self.CreatePool() - self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( - factory_test1_pb2.DESCRIPTOR.serialized_pb) - self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( - factory_test2_pb2.DESCRIPTOR.serialized_pb) - self.pool.Add(self.factory_test1_fd) - self.pool.Add(self.factory_test2_fd) +class DescriptorPoolTestBase(object): def testFindFileByName(self): name1 = 'google/protobuf/internal/factory_test1.proto' @@ -107,6 +101,34 @@ class DescriptorPoolTest(unittest.TestCase): self.assertEqual('google.protobuf.python.internal', file_desc2.package) self.assertIn('Factory2Message', file_desc2.message_types_by_name) + # Tests top level extension. + file_desc3 = self.pool.FindFileContainingSymbol( + 'google.protobuf.python.internal.another_field') + self.assertIsInstance(file_desc3, descriptor.FileDescriptor) + self.assertEqual('google/protobuf/internal/factory_test2.proto', + file_desc3.name) + + # Tests nested extension inside a message. + file_desc4 = self.pool.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message.one_more_field') + self.assertIsInstance(file_desc4, descriptor.FileDescriptor) + self.assertEqual('google/protobuf/internal/factory_test2.proto', + file_desc4.name) + + file_desc5 = self.pool.FindFileContainingSymbol( + 'protobuf_unittest.TestService') + self.assertIsInstance(file_desc5, descriptor.FileDescriptor) + self.assertEqual('google/protobuf/unittest.proto', + file_desc5.name) + + # Tests the generated pool. + assert descriptor_pool.Default().FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message.one_more_field') + assert descriptor_pool.Default().FindFileContainingSymbol( + 'google.protobuf.python.internal.another_field') + assert descriptor_pool.Default().FindFileContainingSymbol( + 'protobuf_unittest.TestService') + def testFindFileContainingSymbolFailure(self): with self.assertRaises(KeyError): self.pool.FindFileContainingSymbol('Does not exist') @@ -119,6 +141,7 @@ class DescriptorPoolTest(unittest.TestCase): self.assertEqual('google.protobuf.python.internal.Factory1Message', msg1.full_name) self.assertEqual(None, msg1.containing_type) + self.assertFalse(msg1.has_options) nested_msg1 = msg1.nested_types[0] self.assertEqual('NestedFactory1Message', nested_msg1.name) @@ -192,6 +215,27 @@ class DescriptorPoolTest(unittest.TestCase): msg2.fields_by_name[name].containing_oneof) self.assertIn(msg2.fields_by_name[name], msg2.oneofs[0].fields) + def testFindTypeErrors(self): + self.assertRaises(TypeError, self.pool.FindExtensionByNumber, '') + + # TODO(jieluo): Fix python to raise correct errors. + if api_implementation.Type() == 'cpp': + self.assertRaises(TypeError, self.pool.FindMethodByName, 0) + self.assertRaises(KeyError, self.pool.FindMethodByName, '') + error_type = TypeError + else: + error_type = AttributeError + self.assertRaises(error_type, self.pool.FindMessageTypeByName, 0) + self.assertRaises(error_type, self.pool.FindFieldByName, 0) + self.assertRaises(error_type, self.pool.FindExtensionByName, 0) + self.assertRaises(error_type, self.pool.FindEnumTypeByName, 0) + self.assertRaises(error_type, self.pool.FindOneofByName, 0) + self.assertRaises(error_type, self.pool.FindServiceByName, 0) + self.assertRaises(error_type, self.pool.FindFileContainingSymbol, 0) + if api_implementation.Type() == 'python': + error_type = KeyError + self.assertRaises(error_type, self.pool.FindFileByName, 0) + def testFindMessageTypeByNameFailure(self): with self.assertRaises(KeyError): self.pool.FindMessageTypeByName('Does not exist') @@ -202,6 +246,7 @@ class DescriptorPoolTest(unittest.TestCase): self.assertIsInstance(enum1, descriptor.EnumDescriptor) self.assertEqual(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number) self.assertEqual(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number) + self.assertFalse(enum1.has_options) nested_enum1 = self.pool.FindEnumTypeByName( 'google.protobuf.python.internal.Factory1Message.NestedFactory1Enum') @@ -230,14 +275,38 @@ class DescriptorPoolTest(unittest.TestCase): self.pool.FindEnumTypeByName('Does not exist') def testFindFieldByName(self): + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # TODO(jieluo): Fix cpp extension to find field correctly + # when descriptor pool is using an underlying database. + return field = self.pool.FindFieldByName( 'google.protobuf.python.internal.Factory1Message.list_value') self.assertEqual(field.name, 'list_value') self.assertEqual(field.label, field.LABEL_REPEATED) + self.assertFalse(field.has_options) + with self.assertRaises(KeyError): self.pool.FindFieldByName('Does not exist') + def testFindOneofByName(self): + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # TODO(jieluo): Fix cpp extension to find oneof correctly + # when descriptor pool is using an underlying database. + return + oneof = self.pool.FindOneofByName( + 'google.protobuf.python.internal.Factory2Message.oneof_field') + self.assertEqual(oneof.name, 'oneof_field') + with self.assertRaises(KeyError): + self.pool.FindOneofByName('Does not exist') + def testFindExtensionByName(self): + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # TODO(jieluo): Fix cpp extension to find extension correctly + # when descriptor pool is using an underlying database. + return # An extension defined in a message. extension = self.pool.FindExtensionByName( 'google.protobuf.python.internal.Factory2Message.one_more_field') @@ -250,6 +319,53 @@ class DescriptorPoolTest(unittest.TestCase): with self.assertRaises(KeyError): self.pool.FindFieldByName('Does not exist') + def testFindAllExtensions(self): + factory1_message = self.pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory1Message') + factory2_message = self.pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory2Message') + # An extension defined in a message. + one_more_field = factory2_message.extensions_by_name['one_more_field'] + self.pool.AddExtensionDescriptor(one_more_field) + # An extension defined at file scope. + factory_test2 = self.pool.FindFileByName( + 'google/protobuf/internal/factory_test2.proto') + another_field = factory_test2.extensions_by_name['another_field'] + self.pool.AddExtensionDescriptor(another_field) + + extensions = self.pool.FindAllExtensions(factory1_message) + expected_extension_numbers = set([one_more_field, another_field]) + self.assertEqual(expected_extension_numbers, set(extensions)) + # Verify that mutating the returned list does not affect the pool. + extensions.append('unexpected_element') + # Get the extensions again, the returned value does not contain the + # 'unexpected_element'. + extensions = self.pool.FindAllExtensions(factory1_message) + self.assertEqual(expected_extension_numbers, set(extensions)) + + def testFindExtensionByNumber(self): + factory1_message = self.pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory1Message') + factory2_message = self.pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory2Message') + # An extension defined in a message. + one_more_field = factory2_message.extensions_by_name['one_more_field'] + self.pool.AddExtensionDescriptor(one_more_field) + # An extension defined at file scope. + factory_test2 = self.pool.FindFileByName( + 'google/protobuf/internal/factory_test2.proto') + another_field = factory_test2.extensions_by_name['another_field'] + self.pool.AddExtensionDescriptor(another_field) + + # An extension defined in a message. + extension = self.pool.FindExtensionByNumber(factory1_message, 1001) + self.assertEqual(extension.name, 'one_more_field') + # An extension defined at file scope. + extension = self.pool.FindExtensionByNumber(factory1_message, 1002) + self.assertEqual(extension.name, 'another_field') + with self.assertRaises(KeyError): + extension = self.pool.FindExtensionByNumber(factory1_message, 1234567) + def testExtensionsAreNotFields(self): with self.assertRaises(KeyError): self.pool.FindFieldByName('google.protobuf.python.internal.another_field') @@ -260,6 +376,12 @@ class DescriptorPoolTest(unittest.TestCase): self.pool.FindExtensionByName( 'google.protobuf.python.internal.Factory1Message.list_value') + def testFindService(self): + service = self.pool.FindServiceByName('protobuf_unittest.TestService') + self.assertEqual(service.full_name, 'protobuf_unittest.TestService') + with self.assertRaises(KeyError): + self.pool.FindServiceByName('Does not exist') + def testUserDefinedDB(self): db = descriptor_database.DescriptorDatabase() self.pool = descriptor_pool.DescriptorPool(db) @@ -268,21 +390,17 @@ class DescriptorPoolTest(unittest.TestCase): self.testFindMessageTypeByName() def testAddSerializedFile(self): + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # Cpp extension cannot call Add on a DescriptorPool + # that uses a DescriptorDatabase. + # TODO(jieluo): Fix python and cpp extension diff. + return self.pool = descriptor_pool.DescriptorPool() self.pool.AddSerializedFile(self.factory_test1_fd.SerializeToString()) self.pool.AddSerializedFile(self.factory_test2_fd.SerializeToString()) self.testFindMessageTypeByName() - def testComplexNesting(self): - test1_desc = descriptor_pb2.FileDescriptorProto.FromString( - descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) - test2_desc = descriptor_pb2.FileDescriptorProto.FromString( - descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb) - self.pool.Add(test1_desc) - self.pool.Add(test2_desc) - TEST1_FILE.CheckFile(self, self.pool) - TEST2_FILE.CheckFile(self, self.pool) - def testEnumDefaultValue(self): """Test the default value of enums which don't start at zero.""" @@ -301,6 +419,12 @@ class DescriptorPoolTest(unittest.TestCase): self.assertIs(file_descriptor, descriptor_pool_test1_pb2.DESCRIPTOR) _CheckDefaultValue(file_descriptor) + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # Cpp extension cannot call Add on a DescriptorPool + # that uses a DescriptorDatabase. + # TODO(jieluo): Fix python and cpp extension diff. + return # Then check the dynamic pool and its internal DescriptorDatabase. descriptor_proto = descriptor_pb2.FileDescriptorProto.FromString( descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) @@ -348,26 +472,166 @@ class DescriptorPoolTest(unittest.TestCase): unittest_pb2.TestAllTypes.DESCRIPTOR.full_name)) _CheckDefaultValues(message_class()) + def testAddFileDescriptor(self): + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # Cpp extension cannot call Add on a DescriptorPool + # that uses a DescriptorDatabase. + # TODO(jieluo): Fix python and cpp extension diff. + return + file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') + self.pool.Add(file_desc) + self.pool.AddSerializedFile(file_desc.SerializeToString()) + + def testComplexNesting(self): + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # Cpp extension cannot call Add on a DescriptorPool + # that uses a DescriptorDatabase. + # TODO(jieluo): Fix python and cpp extension diff. + return + more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString( + more_messages_pb2.DESCRIPTOR.serialized_pb) + test1_desc = descriptor_pb2.FileDescriptorProto.FromString( + descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) + test2_desc = descriptor_pb2.FileDescriptorProto.FromString( + descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(more_messages_desc) + self.pool.Add(test1_desc) + self.pool.Add(test2_desc) + TEST1_FILE.CheckFile(self, self.pool) + TEST2_FILE.CheckFile(self, self.pool) + + def testConflictRegister(self): + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # Cpp extension cannot call Add on a DescriptorPool + # that uses a DescriptorDatabase. + # TODO(jieluo): Fix python and cpp extension diff. + return + unittest_fd = descriptor_pb2.FileDescriptorProto.FromString( + unittest_pb2.DESCRIPTOR.serialized_pb) + conflict_fd = copy.deepcopy(unittest_fd) + conflict_fd.name = 'other_file' + if api_implementation.Type() == 'cpp': + try: + self.pool.Add(unittest_fd) + self.pool.Add(conflict_fd) + except TypeError: + pass + else: + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter('always') + pool = copy.deepcopy(self.pool) + # No warnings to add the same descriptors. + file_descriptor = unittest_pb2.DESCRIPTOR + pool.AddDescriptor( + file_descriptor.message_types_by_name['TestAllTypes']) + pool.AddEnumDescriptor( + file_descriptor.enum_types_by_name['ForeignEnum']) + pool.AddServiceDescriptor( + file_descriptor.services_by_name['TestService']) + pool.AddExtensionDescriptor( + file_descriptor.extensions_by_name['optional_int32_extension']) + self.assertEqual(len(w), 0) + # Check warnings for conflict descriptors with the same name. + pool.Add(unittest_fd) + pool.Add(conflict_fd) + pool.FindFileByName(unittest_fd.name) + pool.FindFileByName(conflict_fd.name) + self.assertTrue(len(w)) + self.assertIs(w[0].category, RuntimeWarning) + self.assertIn('Conflict register for file "other_file": ', + str(w[0].message)) + self.assertIn('already defined in file ' + '"google/protobuf/unittest.proto"', + str(w[0].message)) + + +class DefaultDescriptorPoolTest(DescriptorPoolTestBase, unittest.TestCase): + + def setUp(self): + self.pool = descriptor_pool.Default() + self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test1_pb2.DESCRIPTOR.serialized_pb) + self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + + def testFindMethods(self): + self.assertIs( + self.pool.FindFileByName('google/protobuf/unittest.proto'), + unittest_pb2.DESCRIPTOR) + self.assertIs( + self.pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'), + unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertIs( + self.pool.FindFieldByName( + 'protobuf_unittest.TestAllTypes.optional_int32'), + unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32']) + self.assertIs( + self.pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'), + unittest_pb2.ForeignEnum.DESCRIPTOR) + self.assertIs( + self.pool.FindExtensionByName( + 'protobuf_unittest.optional_int32_extension'), + unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) + self.assertIs( + self.pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'), + unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field']) + self.assertIs( + self.pool.FindServiceByName('protobuf_unittest.TestService'), + unittest_pb2.DESCRIPTOR.services_by_name['TestService']) + + +class CreateDescriptorPoolTest(DescriptorPoolTestBase, unittest.TestCase): + + def setUp(self): + self.pool = descriptor_pool.DescriptorPool() + self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test1_pb2.DESCRIPTOR.serialized_pb) + self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(self.factory_test1_fd) + self.pool.Add(self.factory_test2_fd) + + self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_import_public_pb2.DESCRIPTOR.serialized_pb)) + self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_import_pb2.DESCRIPTOR.serialized_pb)) + self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_pb2.DESCRIPTOR.serialized_pb)) -@unittest.skipIf(api_implementation.Type() != 'cpp', - 'explicit tests of the C++ implementation') -class CppDescriptorPoolTest(DescriptorPoolTest): - # TODO(amauryfa): remove when descriptor_pool.DescriptorPool() creates true - # C++ descriptor pool object for C++ implementation. - def CreatePool(self): - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - return _message.DescriptorPool() +class SecondaryDescriptorFromDescriptorDB(DescriptorPoolTestBase, + unittest.TestCase): + + def setUp(self): + self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test1_pb2.DESCRIPTOR.serialized_pb) + self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + db = descriptor_database.DescriptorDatabase() + db.Add(self.factory_test1_fd) + db.Add(self.factory_test2_fd) + db.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_import_public_pb2.DESCRIPTOR.serialized_pb)) + db.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_import_pb2.DESCRIPTOR.serialized_pb)) + db.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_pb2.DESCRIPTOR.serialized_pb)) + self.pool = descriptor_pool.DescriptorPool(descriptor_db=db) class ProtoFile(object): - def __init__(self, name, package, messages, dependencies=None): + def __init__(self, name, package, messages, dependencies=None, + public_dependencies=None): self.name = name self.package = package self.messages = messages self.dependencies = dependencies or [] + self.public_dependencies = public_dependencies or [] def CheckFile(self, test, pool): file_desc = pool.FindFileByName(self.name) @@ -375,6 +639,8 @@ class ProtoFile(object): test.assertEqual(self.package, file_desc.package) dependencies_names = [f.name for f in file_desc.dependencies] test.assertEqual(self.dependencies, dependencies_names) + public_dependencies_names = [f.name for f in file_desc.public_dependencies] + test.assertEqual(self.public_dependencies, public_dependencies_names) for name, msg_type in self.messages.items(): msg_type.CheckType(test, None, name, file_desc) @@ -426,10 +692,10 @@ class MessageType(object): subtype.CheckType(test, desc, name, file_desc) for index, (name, field) in enumerate(self.field_list): - field.CheckField(test, desc, name, index) + field.CheckField(test, desc, name, index, file_desc) for index, (name, field) in enumerate(self.extensions): - field.CheckField(test, desc, name, index) + field.CheckField(test, desc, name, index, file_desc) class EnumField(object): @@ -439,7 +705,7 @@ class EnumField(object): self.type_name = type_name self.default_value = default_value - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] enum_desc = msg_desc.enum_types_by_name[self.type_name] test.assertEqual(name, field_desc.name) @@ -453,8 +719,10 @@ class EnumField(object): test.assertTrue(field_desc.has_default_value) test.assertEqual(enum_desc.values_by_name[self.default_value].number, field_desc.default_value) + test.assertFalse(enum_desc.values_by_name[self.default_value].has_options) test.assertEqual(msg_desc, field_desc.containing_type) test.assertEqual(enum_desc, field_desc.enum_type) + test.assertEqual(file_desc, enum_desc.file) class MessageField(object): @@ -463,7 +731,7 @@ class MessageField(object): self.number = number self.type_name = type_name - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] field_type_desc = msg_desc.nested_types_by_name[self.type_name] test.assertEqual(name, field_desc.name) @@ -477,6 +745,12 @@ class MessageField(object): test.assertFalse(field_desc.has_default_value) test.assertEqual(msg_desc, field_desc.containing_type) test.assertEqual(field_type_desc, field_desc.message_type) + test.assertEqual(file_desc, field_desc.file) + # TODO(jieluo): Fix python and cpp extension diff for message field + # default value. + if api_implementation.Type() == 'cpp': + test.assertRaises( + NotImplementedError, getattr, field_desc, 'default_value') class StringField(object): @@ -485,7 +759,7 @@ class StringField(object): self.number = number self.default_value = default_value - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] test.assertEqual(name, field_desc.name) expected_field_full_name = '.'.join([msg_desc.full_name, name]) @@ -497,6 +771,7 @@ class StringField(object): field_desc.cpp_type) test.assertTrue(field_desc.has_default_value) test.assertEqual(self.default_value, field_desc.default_value) + test.assertEqual(file_desc, field_desc.file) class ExtensionField(object): @@ -505,7 +780,7 @@ class ExtensionField(object): self.number = number self.extended_type = extended_type - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.extensions_by_name[name] test.assertEqual(name, field_desc.name) expected_field_full_name = '.'.join([msg_desc.full_name, name]) @@ -520,6 +795,7 @@ class ExtensionField(object): test.assertEqual(msg_desc, field_desc.extension_scope) test.assertEqual(msg_desc, field_desc.message_type) test.assertEqual(self.extended_type, field_desc.containing_type.name) + test.assertEqual(file_desc, field_desc.file) class AddDescriptorTest(unittest.TestCase): @@ -555,7 +831,7 @@ class AddDescriptorTest(unittest.TestCase): prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').name) @unittest.skipIf(api_implementation.Type() == 'cpp', - 'With the cpp implementation, Add() must be called first') + 'With the cpp implementation, Add() must be called first') def testMessage(self): self._TestMessage('') self._TestMessage('.') @@ -591,13 +867,24 @@ class AddDescriptorTest(unittest.TestCase): prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').name) @unittest.skipIf(api_implementation.Type() == 'cpp', - 'With the cpp implementation, Add() must be called first') + 'With the cpp implementation, Add() must be called first') def testEnum(self): self._TestEnum('') self._TestEnum('.') @unittest.skipIf(api_implementation.Type() == 'cpp', - 'With the cpp implementation, Add() must be called first') + 'With the cpp implementation, Add() must be called first') + def testService(self): + pool = descriptor_pool.DescriptorPool() + with self.assertRaises(KeyError): + pool.FindServiceByName('protobuf_unittest.TestService') + pool.AddServiceDescriptor(unittest_pb2._TESTSERVICE) + self.assertEqual( + 'protobuf_unittest.TestService', + pool.FindServiceByName('protobuf_unittest.TestService').full_name) + + @unittest.skipIf(api_implementation.Type() == 'cpp', + 'With the cpp implementation, Add() must be called first') def testFile(self): pool = descriptor_pool.DescriptorPool() pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR) @@ -612,18 +899,9 @@ class AddDescriptorTest(unittest.TestCase): pool.FindFileContainingSymbol( 'protobuf_unittest.TestAllTypes') - def _GetDescriptorPoolClass(self): - # Test with both implementations of descriptor pools. - if api_implementation.Type() == 'cpp': - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - return _message.DescriptorPool - else: - return descriptor_pool.DescriptorPool - def testEmptyDescriptorPool(self): - # Check that an empty DescriptorPool() contains no message. - pool = self._GetDescriptorPoolClass()() + # Check that an empty DescriptorPool() contains no messages. + pool = descriptor_pool.DescriptorPool() proto_file_name = descriptor_pb2.DESCRIPTOR.name self.assertRaises(KeyError, pool.FindFileByName, proto_file_name) # Add the above file to the pool @@ -635,7 +913,7 @@ class AddDescriptorTest(unittest.TestCase): def testCustomDescriptorPool(self): # Create a new pool, and add a file descriptor. - pool = self._GetDescriptorPoolClass()() + pool = descriptor_pool.DescriptorPool() file_desc = descriptor_pb2.FileDescriptorProto( name='some/file.proto', package='package') file_desc.message_type.add(name='Message') @@ -644,43 +922,55 @@ class AddDescriptorTest(unittest.TestCase): 'some/file.proto') self.assertEqual(pool.FindMessageTypeByName('package.Message').name, 'Message') - - -@unittest.skipIf( - api_implementation.Type() != 'cpp', - 'default_pool is only supported by the C++ implementation') -class DefaultPoolTest(unittest.TestCase): - - def testFindMethods(self): - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - pool = _message.default_pool - self.assertIs( - pool.FindFileByName('google/protobuf/unittest.proto'), - unittest_pb2.DESCRIPTOR) - self.assertIs( - pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'), - unittest_pb2.TestAllTypes.DESCRIPTOR) - self.assertIs( - pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'), - unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32']) - self.assertIs( - pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'), - unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) - self.assertIs( - pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'), - unittest_pb2.ForeignEnum.DESCRIPTOR) - self.assertIs( - pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'), - unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field']) - - def testAddFileDescriptor(self): - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - pool = _message.default_pool - file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') - pool.Add(file_desc) - pool.AddSerializedFile(file_desc.SerializeToString()) + # Test no package + file_proto = descriptor_pb2.FileDescriptorProto( + name='some/filename/container.proto') + message_proto = file_proto.message_type.add( + name='TopMessage') + message_proto.field.add( + name='bb', + number=1, + type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32, + label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL) + enum_proto = file_proto.enum_type.add(name='TopEnum') + enum_proto.value.add(name='FOREIGN_FOO', number=4) + file_proto.service.add(name='TopService') + pool = descriptor_pool.DescriptorPool() + pool.Add(file_proto) + self.assertEqual('TopMessage', + pool.FindMessageTypeByName('TopMessage').name) + self.assertEqual('TopEnum', pool.FindEnumTypeByName('TopEnum').name) + self.assertEqual('TopService', pool.FindServiceByName('TopService').name) + + def testFileDescriptorOptionsWithCustomDescriptorPool(self): + # Create a descriptor pool, and add a new FileDescriptorProto to it. + pool = descriptor_pool.DescriptorPool() + file_name = 'file_descriptor_options_with_custom_descriptor_pool.proto' + file_descriptor_proto = descriptor_pb2.FileDescriptorProto(name=file_name) + extension_id = file_options_test_pb2.foo_options + file_descriptor_proto.options.Extensions[extension_id].foo_name = 'foo' + pool.Add(file_descriptor_proto) + # The options set on the FileDescriptorProto should be available in the + # descriptor even if they contain extensions that cannot be deserialized + # using the pool. + file_descriptor = pool.FindFileByName(file_name) + options = file_descriptor.GetOptions() + self.assertEqual('foo', options.Extensions[extension_id].foo_name) + # The object returned by GetOptions() is cached. + self.assertIs(options, file_descriptor.GetOptions()) + + def testAddTypeError(self): + pool = descriptor_pool.DescriptorPool() + with self.assertRaises(TypeError): + pool.AddDescriptor(0) + with self.assertRaises(TypeError): + pool.AddEnumDescriptor(0) + with self.assertRaises(TypeError): + pool.AddServiceDescriptor(0) + with self.assertRaises(TypeError): + pool.AddExtensionDescriptor(0) + with self.assertRaises(TypeError): + pool.AddFileDescriptor(0) TEST1_FILE = ProtoFile( @@ -756,7 +1046,9 @@ TEST2_FILE = ProtoFile( ExtensionField(1001, 'DescriptorPoolTest1')), ]), }, - dependencies=['google/protobuf/internal/descriptor_pool_test1.proto']) + dependencies=['google/protobuf/internal/descriptor_pool_test1.proto', + 'google/protobuf/internal/more_messages.proto'], + public_dependencies=['google/protobuf/internal/more_messages.proto']) if __name__ == '__main__': |