diff options
Diffstat (limited to 'python/google/protobuf/internal/unknown_fields_test.py')
-rwxr-xr-x | python/google/protobuf/internal/unknown_fields_test.py | 200 |
1 files changed, 118 insertions, 82 deletions
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index 9685b8b4..8b7de2e7 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -36,7 +36,7 @@ __author__ = 'bohdank@google.com (Bohdan Koval)' try: - import unittest2 as unittest + import unittest2 as unittest #PY26 except ImportError: import unittest from google.protobuf import unittest_mset_pb2 @@ -47,16 +47,23 @@ from google.protobuf.internal import encoder from google.protobuf.internal import message_set_extensions_pb2 from google.protobuf.internal import missing_enum_values_pb2 from google.protobuf.internal import test_util +from google.protobuf.internal import testing_refleaks from google.protobuf.internal import type_checkers -def SkipIfCppImplementation(func): +BaseTestCase = testing_refleaks.BaseTestCase + + +# CheckUnknownField() cannot be used by the C++ implementation because +# some protect members are called. It is not a behavior difference +# for python and C++ implementation. +def SkipCheckUnknownFieldIfCppImplementation(func): return unittest.skipIf( api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, - 'C++ implementation does not expose unknown fields to Python')(func) + 'Addtional test for pure python involved protect members')(func) -class UnknownFieldsTest(unittest.TestCase): +class UnknownFieldsTest(BaseTestCase): def setUp(self): self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR @@ -73,11 +80,23 @@ class UnknownFieldsTest(unittest.TestCase): # stdout. self.assertTrue(data == self.all_fields_data) - def testSerializeProto3(self): - # Verify that proto3 doesn't preserve unknown fields. + def expectSerializeProto3(self, preserve): message = unittest_proto3_arena_pb2.TestEmptyMessage() message.ParseFromString(self.all_fields_data) - self.assertEqual(0, len(message.SerializeToString())) + if preserve: + self.assertEqual(self.all_fields_data, message.SerializeToString()) + else: + self.assertEqual(0, len(message.SerializeToString())) + + def testSerializeProto3(self): + # Verify that proto3 unknown fields behavior. + default_preserve = (api_implementation + .GetPythonProto3PreserveUnknownsDefault()) + self.expectSerializeProto3(default_preserve) + api_implementation.SetPythonProto3PreserveUnknownsDefault( + not default_preserve) + self.expectSerializeProto3(not default_preserve) + api_implementation.SetPythonProto3PreserveUnknownsDefault(default_preserve) def testByteSize(self): self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize()) @@ -119,8 +138,28 @@ class UnknownFieldsTest(unittest.TestCase): message.ParseFromString(self.all_fields.SerializeToString()) self.assertNotEqual(self.empty_message, message) - -class UnknownFieldsAccessorsTest(unittest.TestCase): + def testDiscardUnknownFields(self): + self.empty_message.DiscardUnknownFields() + self.assertEqual(b'', self.empty_message.SerializeToString()) + # Test message field and repeated message field. + message = unittest_pb2.TestAllTypes() + other_message = unittest_pb2.TestAllTypes() + other_message.optional_string = 'discard' + message.optional_nested_message.ParseFromString( + other_message.SerializeToString()) + message.repeated_nested_message.add().ParseFromString( + other_message.SerializeToString()) + self.assertNotEqual( + b'', message.optional_nested_message.SerializeToString()) + self.assertNotEqual( + b'', message.repeated_nested_message[0].SerializeToString()) + message.DiscardUnknownFields() + self.assertEqual(b'', message.optional_nested_message.SerializeToString()) + self.assertEqual( + b'', message.repeated_nested_message[0].SerializeToString()) + + +class UnknownFieldsAccessorsTest(BaseTestCase): def setUp(self): self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR @@ -129,60 +168,51 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): self.all_fields_data = self.all_fields.SerializeToString() self.empty_message = unittest_pb2.TestEmptyMessage() self.empty_message.ParseFromString(self.all_fields_data) - if api_implementation.Type() != 'cpp': - # _unknown_fields is an implementation detail. - self.unknown_fields = self.empty_message._unknown_fields - # All the tests that use GetField() check an implementation detail of the - # Python implementation, which stores unknown fields as serialized strings. - # These tests are skipped by the C++ implementation: it's enough to check that - # the message is correctly serialized. + # CheckUnknownField() is an additional Pure Python check which checks + # a detail of unknown fields. It cannot be used by the C++ + # implementation because some protect members are called. + # The test is added for historical reasons. It is not necessary as + # serialized string is checked. - def GetField(self, name): + def CheckUnknownField(self, name, expected_value): field_descriptor = self.descriptor.fields_by_name[name] wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] field_tag = encoder.TagBytes(field_descriptor.number, wire_type) result_dict = {} - for tag_bytes, value in self.unknown_fields: + for tag_bytes, value in self.empty_message._unknown_fields: if tag_bytes == field_tag: decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0] decoder(value, 0, len(value), self.all_fields, result_dict) - return result_dict[field_descriptor] - - @SkipIfCppImplementation - def testEnum(self): - value = self.GetField('optional_nested_enum') - self.assertEqual(self.all_fields.optional_nested_enum, value) - - @SkipIfCppImplementation - def testRepeatedEnum(self): - value = self.GetField('repeated_nested_enum') - self.assertEqual(self.all_fields.repeated_nested_enum, value) - - @SkipIfCppImplementation - def testVarint(self): - value = self.GetField('optional_int32') - self.assertEqual(self.all_fields.optional_int32, value) - - @SkipIfCppImplementation - def testFixed32(self): - value = self.GetField('optional_fixed32') - self.assertEqual(self.all_fields.optional_fixed32, value) - - @SkipIfCppImplementation - def testFixed64(self): - value = self.GetField('optional_fixed64') - self.assertEqual(self.all_fields.optional_fixed64, value) - - @SkipIfCppImplementation - def testLengthDelimited(self): - value = self.GetField('optional_string') - self.assertEqual(self.all_fields.optional_string, value) - - @SkipIfCppImplementation - def testGroup(self): - value = self.GetField('optionalgroup') - self.assertEqual(self.all_fields.optionalgroup, value) + self.assertEqual(expected_value, result_dict[field_descriptor]) + + @SkipCheckUnknownFieldIfCppImplementation + def testCheckUnknownFieldValue(self): + # Test enum. + self.CheckUnknownField('optional_nested_enum', + self.all_fields.optional_nested_enum) + # Test repeated enum. + self.CheckUnknownField('repeated_nested_enum', + self.all_fields.repeated_nested_enum) + + # Test varint. + self.CheckUnknownField('optional_int32', + self.all_fields.optional_int32) + # Test fixed32. + self.CheckUnknownField('optional_fixed32', + self.all_fields.optional_fixed32) + + # Test fixed64. + self.CheckUnknownField('optional_fixed64', + self.all_fields.optional_fixed64) + + # Test lengthd elimited. + self.CheckUnknownField('optional_string', + self.all_fields.optional_string) + + # Test group. + self.CheckUnknownField('optionalgroup', + self.all_fields.optionalgroup) def testCopyFrom(self): message = unittest_pb2.TestEmptyMessage() @@ -221,45 +251,44 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): self.assertEqual(message.SerializeToString(), self.all_fields_data) -class UnknownEnumValuesTest(unittest.TestCase): +class UnknownEnumValuesTest(BaseTestCase): def setUp(self): self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR self.message = missing_enum_values_pb2.TestEnumValues() + # TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum. self.message.optional_nested_enum = ( - missing_enum_values_pb2.TestEnumValues.ZERO) + missing_enum_values_pb2.TestEnumValues.ZERO) self.message.repeated_nested_enum.extend([ - missing_enum_values_pb2.TestEnumValues.ZERO, - missing_enum_values_pb2.TestEnumValues.ONE, - ]) + missing_enum_values_pb2.TestEnumValues.ZERO, + missing_enum_values_pb2.TestEnumValues.ONE, + ]) self.message.packed_nested_enum.extend([ - missing_enum_values_pb2.TestEnumValues.ZERO, - missing_enum_values_pb2.TestEnumValues.ONE, - ]) + missing_enum_values_pb2.TestEnumValues.ZERO, + missing_enum_values_pb2.TestEnumValues.ONE, + ]) self.message_data = self.message.SerializeToString() self.missing_message = missing_enum_values_pb2.TestMissingEnumValues() self.missing_message.ParseFromString(self.message_data) - if api_implementation.Type() != 'cpp': - # _unknown_fields is an implementation detail. - self.unknown_fields = self.missing_message._unknown_fields - # All the tests that use GetField() check an implementation detail of the - # Python implementation, which stores unknown fields as serialized strings. - # These tests are skipped by the C++ implementation: it's enough to check that - # the message is correctly serialized. + # CheckUnknownField() is an additional Pure Python check which checks + # a detail of unknown fields. It cannot be used by the C++ + # implementation because some protect members are called. + # The test is added for historical reasons. It is not necessary as + # serialized string is checked. - def GetField(self, name): + def CheckUnknownField(self, name, expected_value): field_descriptor = self.descriptor.fields_by_name[name] wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] field_tag = encoder.TagBytes(field_descriptor.number, wire_type) result_dict = {} - for tag_bytes, value in self.unknown_fields: + for tag_bytes, value in self.missing_message._unknown_fields: if tag_bytes == field_tag: decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[ - tag_bytes][0] + tag_bytes][0] decoder(value, 0, len(value), self.message, result_dict) - return result_dict[field_descriptor] + self.assertEqual(expected_value, result_dict[field_descriptor]) def testUnknownParseMismatchEnumValue(self): just_string = missing_enum_values_pb2.JustString() @@ -274,21 +303,28 @@ class UnknownEnumValuesTest(unittest.TestCase): # default value. self.assertEqual(missing.optional_nested_enum, 0) - @SkipIfCppImplementation def testUnknownEnumValue(self): self.assertFalse(self.missing_message.HasField('optional_nested_enum')) - value = self.GetField('optional_nested_enum') - self.assertEqual(self.message.optional_nested_enum, value) + self.assertEqual(self.missing_message.optional_nested_enum, 2) + # Clear does not do anything. + serialized = self.missing_message.SerializeToString() + self.missing_message.ClearField('optional_nested_enum') + self.assertEqual(self.missing_message.SerializeToString(), serialized) - @SkipIfCppImplementation def testUnknownRepeatedEnumValue(self): - value = self.GetField('repeated_nested_enum') - self.assertEqual(self.message.repeated_nested_enum, value) + self.assertEqual([], self.missing_message.repeated_nested_enum) - @SkipIfCppImplementation def testUnknownPackedEnumValue(self): - value = self.GetField('packed_nested_enum') - self.assertEqual(self.message.packed_nested_enum, value) + self.assertEqual([], self.missing_message.packed_nested_enum) + + @SkipCheckUnknownFieldIfCppImplementation + def testCheckUnknownFieldValueForEnum(self): + self.CheckUnknownField('optional_nested_enum', + self.message.optional_nested_enum) + self.CheckUnknownField('repeated_nested_enum', + self.message.repeated_nested_enum) + self.CheckUnknownField('packed_nested_enum', + self.message.packed_nested_enum) def testRoundTrip(self): new_message = missing_enum_values_pb2.TestEnumValues() |