diff options
Diffstat (limited to 'python/google/protobuf/internal/descriptor_pool_test.py')
-rw-r--r-- | python/google/protobuf/internal/descriptor_pool_test.py | 46 |
1 files changed, 28 insertions, 18 deletions
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index c1733a48..6015e6f8 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -131,11 +131,19 @@ class DescriptorPoolTest(unittest.TestCase): self.assertEqual('google/protobuf/internal/factory_test2.proto', file_desc4.name) + file_desc5 = self.pool.FindFileContainingSymbol( + 'protobuf_unittest.TestService') + self.assertIsInstance(file_desc5, descriptor.FileDescriptor) + self.assertEqual('google/protobuf/unittest.proto', + file_desc5.name) + # Tests the generated pool. assert descriptor_pool.Default().FindFileContainingSymbol( 'google.protobuf.python.internal.Factory2Message.one_more_field') assert descriptor_pool.Default().FindFileContainingSymbol( 'google.protobuf.python.internal.another_field') + assert descriptor_pool.Default().FindFileContainingSymbol( + 'protobuf_unittest.TestService') def testFindFileContainingSymbolFailure(self): with self.assertRaises(KeyError): @@ -506,10 +514,10 @@ class MessageType(object): subtype.CheckType(test, desc, name, file_desc) for index, (name, field) in enumerate(self.field_list): - field.CheckField(test, desc, name, index) + field.CheckField(test, desc, name, index, file_desc) for index, (name, field) in enumerate(self.extensions): - field.CheckField(test, desc, name, index) + field.CheckField(test, desc, name, index, file_desc) class EnumField(object): @@ -519,7 +527,7 @@ class EnumField(object): self.type_name = type_name self.default_value = default_value - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] enum_desc = msg_desc.enum_types_by_name[self.type_name] test.assertEqual(name, field_desc.name) @@ -536,6 +544,7 @@ class EnumField(object): test.assertFalse(enum_desc.values_by_name[self.default_value].has_options) test.assertEqual(msg_desc, field_desc.containing_type) test.assertEqual(enum_desc, field_desc.enum_type) + test.assertEqual(file_desc, enum_desc.file) class MessageField(object): @@ -544,7 +553,7 @@ class MessageField(object): self.number = number self.type_name = type_name - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] field_type_desc = msg_desc.nested_types_by_name[self.type_name] test.assertEqual(name, field_desc.name) @@ -558,6 +567,7 @@ class MessageField(object): test.assertFalse(field_desc.has_default_value) test.assertEqual(msg_desc, field_desc.containing_type) test.assertEqual(field_type_desc, field_desc.message_type) + test.assertEqual(file_desc, field_desc.file) class StringField(object): @@ -566,7 +576,7 @@ class StringField(object): self.number = number self.default_value = default_value - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] test.assertEqual(name, field_desc.name) expected_field_full_name = '.'.join([msg_desc.full_name, name]) @@ -578,6 +588,7 @@ class StringField(object): field_desc.cpp_type) test.assertTrue(field_desc.has_default_value) test.assertEqual(self.default_value, field_desc.default_value) + test.assertEqual(file_desc, field_desc.file) class ExtensionField(object): @@ -586,7 +597,7 @@ class ExtensionField(object): self.number = number self.extended_type = extended_type - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.extensions_by_name[name] test.assertEqual(name, field_desc.name) expected_field_full_name = '.'.join([msg_desc.full_name, name]) @@ -601,6 +612,7 @@ class ExtensionField(object): test.assertEqual(msg_desc, field_desc.extension_scope) test.assertEqual(msg_desc, field_desc.message_type) test.assertEqual(self.extended_type, field_desc.containing_type.name) + test.assertEqual(file_desc, field_desc.file) class AddDescriptorTest(unittest.TestCase): @@ -746,15 +758,10 @@ class AddDescriptorTest(unittest.TestCase): self.assertIs(options, file_descriptor.GetOptions()) -@unittest.skipIf( - api_implementation.Type() != 'cpp', - 'default_pool is only supported by the C++ implementation') class DefaultPoolTest(unittest.TestCase): def testFindMethods(self): - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - pool = _message.default_pool + pool = descriptor_pool.Default() self.assertIs( pool.FindFileByName('google/protobuf/unittest.proto'), unittest_pb2.DESCRIPTOR) @@ -765,19 +772,22 @@ class DefaultPoolTest(unittest.TestCase): pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'), unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32']) self.assertIs( - pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'), - unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) - self.assertIs( pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'), unittest_pb2.ForeignEnum.DESCRIPTOR) + if api_implementation.Type() != 'cpp': + self.skipTest('Only the C++ implementation correctly indexes all types') + self.assertIs( + pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'), + unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) self.assertIs( pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'), unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field']) + self.assertIs( + pool.FindServiceByName('protobuf_unittest.TestService'), + unittest_pb2.DESCRIPTOR.services_by_name['TestService']) def testAddFileDescriptor(self): - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - pool = _message.default_pool + pool = descriptor_pool.Default() file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') pool.Add(file_desc) pool.AddSerializedFile(file_desc.SerializeToString()) |