aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/internal/reflection_test.py
diff options
context:
space:
mode:
authorGravatar jieluo@google.com <jieluo@google.com@630680e5-0e50-0410-840e-4b1c322b438d>2014-08-12 21:10:30 +0000
committerGravatar jieluo@google.com <jieluo@google.com@630680e5-0e50-0410-840e-4b1c322b438d>2014-08-12 21:10:30 +0000
commitbde4a3254a7de58911941b0fbf38e9dd992de973 (patch)
tree02b151c2ec6e9be2e9d5ea0efc406aabe6958ae7 /python/google/protobuf/internal/reflection_test.py
parentd7339318a33c5f9e8b5dded4077223fbd4ebf229 (diff)
down integrate python opensource to svn
Diffstat (limited to 'python/google/protobuf/internal/reflection_test.py')
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py419
1 files changed, 341 insertions, 78 deletions
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index ed286461..b3c414c7 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -37,11 +37,12 @@ pure-Python protocol compiler.
__author__ = 'robinson@google.com (Will Robinson)'
+import copy
import gc
import operator
import struct
-import unittest
+from google.apputils import basetest
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
@@ -49,6 +50,7 @@ from google.protobuf import descriptor_pb2
from google.protobuf import descriptor
from google.protobuf import message
from google.protobuf import reflection
+from google.protobuf import text_format
from google.protobuf.internal import api_implementation
from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import more_messages_pb2
@@ -102,7 +104,7 @@ class _MiniDecoder(object):
return self._pos == len(self._bytes)
-class ReflectionTest(unittest.TestCase):
+class ReflectionTest(basetest.TestCase):
def assertListsEqual(self, values, others):
self.assertEqual(len(values), len(others))
@@ -533,7 +535,7 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(0.0, proto.optional_double)
self.assertEqual(False, proto.optional_bool)
self.assertEqual('', proto.optional_string)
- self.assertEqual('', proto.optional_bytes)
+ self.assertEqual(b'', proto.optional_bytes)
self.assertEqual(41, proto.default_int32)
self.assertEqual(42, proto.default_int64)
@@ -549,7 +551,7 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(52e3, proto.default_double)
self.assertEqual(True, proto.default_bool)
self.assertEqual('hello', proto.default_string)
- self.assertEqual('world', proto.default_bytes)
+ self.assertEqual(b'world', proto.default_bytes)
self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
self.assertEqual(unittest_import_pb2.IMPORT_BAR,
@@ -566,6 +568,17 @@ class ReflectionTest(unittest.TestCase):
proto = unittest_pb2.TestAllTypes()
self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
+ def testClearRemovesChildren(self):
+ # Make sure there aren't any implementation bugs that are only partially
+ # clearing the message (which can happen in the more complex C++
+ # implementation which has parallel message lists).
+ proto = unittest_pb2.TestRequiredForeign()
+ for i in range(10):
+ proto.repeated_message.add()
+ proto2 = unittest_pb2.TestRequiredForeign()
+ proto.CopyFrom(proto2)
+ self.assertRaises(IndexError, lambda: proto.repeated_message[5])
+
def testDisallowedAssignments(self):
# It's illegal to assign values directly to repeated fields
# or to nonrepeated composite fields. Ensure that this fails.
@@ -594,6 +607,30 @@ class ReflectionTest(unittest.TestCase):
self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
+ def testIntegerTypes(self):
+ def TestGetAndDeserialize(field_name, value, expected_type):
+ proto = unittest_pb2.TestAllTypes()
+ setattr(proto, field_name, value)
+ self.assertTrue(isinstance(getattr(proto, field_name), expected_type))
+ proto2 = unittest_pb2.TestAllTypes()
+ proto2.ParseFromString(proto.SerializeToString())
+ self.assertTrue(isinstance(getattr(proto2, field_name), expected_type))
+
+ TestGetAndDeserialize('optional_int32', 1, int)
+ TestGetAndDeserialize('optional_int32', 1 << 30, int)
+ TestGetAndDeserialize('optional_uint32', 1 << 30, int)
+ if struct.calcsize('L') == 4:
+ # Python only has signed ints, so 32-bit python can't fit an uint32
+ # in an int.
+ TestGetAndDeserialize('optional_uint32', 1 << 31, long)
+ else:
+ # 64-bit python can fit uint32 inside an int
+ TestGetAndDeserialize('optional_uint32', 1 << 31, int)
+ TestGetAndDeserialize('optional_int64', 1 << 30, long)
+ TestGetAndDeserialize('optional_int64', 1 << 60, long)
+ TestGetAndDeserialize('optional_uint64', 1 << 30, long)
+ TestGetAndDeserialize('optional_uint64', 1 << 60, long)
+
def testSingleScalarBoundsChecking(self):
def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
pb = unittest_pb2.TestAllTypes()
@@ -613,29 +650,6 @@ class ReflectionTest(unittest.TestCase):
pb.optional_nested_enum = 1
self.assertEqual(1, pb.optional_nested_enum)
- # Invalid enum values.
- pb.optional_nested_enum = 0
- self.assertEqual(0, pb.optional_nested_enum)
-
- bytes_size_before = pb.ByteSize()
-
- pb.optional_nested_enum = 4
- self.assertEqual(4, pb.optional_nested_enum)
-
- pb.optional_nested_enum = 0
- self.assertEqual(0, pb.optional_nested_enum)
-
- # Make sure that setting the same enum field doesn't just add unknown
- # fields (but overwrites them).
- self.assertEqual(bytes_size_before, pb.ByteSize())
-
- # Is the invalid value preserved after serialization?
- serialized = pb.SerializeToString()
- pb2 = unittest_pb2.TestAllTypes()
- pb2.ParseFromString(serialized)
- self.assertEqual(0, pb2.optional_nested_enum)
- self.assertEqual(pb, pb2)
-
def testRepeatedScalarTypeSafety(self):
proto = unittest_pb2.TestAllTypes()
self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
@@ -749,9 +763,9 @@ class ReflectionTest(unittest.TestCase):
unittest_pb2.ForeignEnum.items())
proto = unittest_pb2.TestAllTypes()
- self.assertEqual(['FOO', 'BAR', 'BAZ'], proto.NestedEnum.keys())
- self.assertEqual([1, 2, 3], proto.NestedEnum.values())
- self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3)],
+ self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys())
+ self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values())
+ self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)],
proto.NestedEnum.items())
def testRepeatedScalars(self):
@@ -1155,6 +1169,14 @@ class ReflectionTest(unittest.TestCase):
self.assertTrue(required is not extendee_proto.Extensions[extension])
self.assertTrue(not extendee_proto.HasExtension(extension))
+ def testRegisteredExtensions(self):
+ self.assertTrue('protobuf_unittest.optional_int32_extension' in
+ unittest_pb2.TestAllExtensions._extensions_by_name)
+ self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number)
+ # Make sure extensions haven't been registered into types that shouldn't
+ # have any.
+ self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name))
+
# If message A directly contains message B, and
# a.HasField('b') is currently False, then mutating any
# extension in B should change a.HasField('b') to True
@@ -1451,6 +1473,19 @@ class ReflectionTest(unittest.TestCase):
proto2 = unittest_pb2.TestAllExtensions()
self.assertRaises(TypeError, proto1.CopyFrom, proto2)
+ def testDeepCopy(self):
+ proto1 = unittest_pb2.TestAllTypes()
+ proto1.optional_int32 = 1
+ proto2 = copy.deepcopy(proto1)
+ self.assertEqual(1, proto2.optional_int32)
+
+ proto1.repeated_int32.append(2)
+ proto1.repeated_int32.append(3)
+ container = copy.deepcopy(proto1.repeated_int32)
+ self.assertEqual([2, 3], container)
+
+ # TODO(anuraag): Implement deepcopy for repeated composite / extension dict
+
def testClear(self):
proto = unittest_pb2.TestAllTypes()
# C++ implementation does not support lazy fields right now so leave it
@@ -1496,11 +1531,23 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(6, foreign.c)
nested.bb = 15
foreign.c = 16
- self.assertTrue(not proto.HasField('optional_nested_message'))
+ self.assertFalse(proto.HasField('optional_nested_message'))
self.assertEqual(0, proto.optional_nested_message.bb)
- self.assertTrue(not proto.HasField('optional_foreign_message'))
+ self.assertFalse(proto.HasField('optional_foreign_message'))
self.assertEqual(0, proto.optional_foreign_message.c)
+ def testOneOf(self):
+ proto = unittest_pb2.TestAllTypes()
+ proto.oneof_uint32 = 10
+ proto.oneof_nested_message.bb = 11
+ self.assertEqual(11, proto.oneof_nested_message.bb)
+ self.assertFalse(proto.HasField('oneof_uint32'))
+ nested = proto.oneof_nested_message
+ proto.oneof_string = 'abc'
+ self.assertEqual('abc', proto.oneof_string)
+ self.assertEqual(11, nested.bb)
+ self.assertFalse(proto.HasField('oneof_nested_message'))
+
def assertInitialized(self, proto):
self.assertTrue(proto.IsInitialized())
# Neither method should raise an exception.
@@ -1571,6 +1618,40 @@ class ReflectionTest(unittest.TestCase):
self.assertFalse(proto.IsInitialized(errors))
self.assertEqual(errors, ['a', 'b', 'c'])
+ @basetest.unittest.skipIf(
+ api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
+ 'Errors are only available from the most recent C++ implementation.')
+ def testFileDescriptorErrors(self):
+ file_name = 'test_file_descriptor_errors.proto'
+ package_name = 'test_file_descriptor_errors.proto'
+ file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
+ file_descriptor_proto.name = file_name
+ file_descriptor_proto.package = package_name
+ m1 = file_descriptor_proto.message_type.add()
+ m1.name = 'msg1'
+ # Compiles the proto into the C++ descriptor pool
+ descriptor.FileDescriptor(
+ file_name,
+ package_name,
+ serialized_pb=file_descriptor_proto.SerializeToString())
+ # Add a FileDescriptorProto that has duplicate symbols
+ another_file_name = 'another_test_file_descriptor_errors.proto'
+ file_descriptor_proto.name = another_file_name
+ m2 = file_descriptor_proto.message_type.add()
+ m2.name = 'msg2'
+ with self.assertRaises(TypeError) as cm:
+ descriptor.FileDescriptor(
+ another_file_name,
+ package_name,
+ serialized_pb=file_descriptor_proto.SerializeToString())
+ self.assertTrue(hasattr(cm, 'exception'), '%s not raised' %
+ getattr(cm.expected, '__name__', cm.expected))
+ self.assertIn('test_file_descriptor_errors.proto', str(cm.exception))
+ # Error message will say something about this definition being a
+ # duplicate, though we don't check the message exactly to avoid a
+ # dependency on the C++ logging code.
+ self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
+
def testStringUTF8Encoding(self):
proto = unittest_pb2.TestAllTypes()
@@ -1588,17 +1669,15 @@ class ReflectionTest(unittest.TestCase):
proto.optional_string = str('Testing')
self.assertEqual(proto.optional_string, unicode('Testing'))
- if api_implementation.Type() == 'python':
- # Values of type 'str' are also accepted as long as they can be
- # encoded in UTF-8.
- self.assertEqual(type(proto.optional_string), str)
-
# Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII.
self.assertRaises(ValueError,
- setattr, proto, 'optional_string', str('a\x80a'))
- # Assign a 'str' object which contains a UTF-8 encoded string.
- self.assertRaises(ValueError,
- setattr, proto, 'optional_string', 'Тест')
+ setattr, proto, 'optional_string', b'a\x80a')
+ if str is bytes: # PY2
+ # Assign a 'str' object which contains a UTF-8 encoded string.
+ self.assertRaises(ValueError,
+ setattr, proto, 'optional_string', 'Тест')
+ else:
+ proto.optional_string = 'Тест'
# No exception thrown.
proto.optional_string = 'abc'
@@ -1621,7 +1700,8 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(proto.ByteSize(), len(serialized))
raw = unittest_mset_pb2.RawMessageSet()
- raw.MergeFromString(serialized)
+ bytes_read = raw.MergeFromString(serialized)
+ self.assertEqual(len(serialized), bytes_read)
message2 = unittest_mset_pb2.TestMessageSetExtension2()
@@ -1632,7 +1712,8 @@ class ReflectionTest(unittest.TestCase):
# Check the actual bytes on the wire.
self.assertTrue(
raw.item[0].message.endswith(test_utf8_bytes))
- message2.MergeFromString(raw.item[0].message)
+ bytes_read = message2.MergeFromString(raw.item[0].message)
+ self.assertEqual(len(raw.item[0].message), bytes_read)
self.assertEqual(type(message2.str), unicode)
self.assertEqual(message2.str, test_utf8)
@@ -1643,17 +1724,22 @@ class ReflectionTest(unittest.TestCase):
# MergeFromString and thus has no way to throw the exception.
#
# The pure Python API always returns objects of type 'unicode' (UTF-8
- # encoded), or 'str' (in 7 bit ASCII).
- bytes = raw.item[0].message.replace(
- test_utf8_bytes, len(test_utf8_bytes) * '\xff')
+ # encoded), or 'bytes' (in 7 bit ASCII).
+ badbytes = raw.item[0].message.replace(
+ test_utf8_bytes, len(test_utf8_bytes) * b'\xff')
unicode_decode_failed = False
try:
- message2.MergeFromString(bytes)
- except UnicodeDecodeError as e:
+ message2.MergeFromString(badbytes)
+ except UnicodeDecodeError:
unicode_decode_failed = True
string_field = message2.str
- self.assertTrue(unicode_decode_failed or type(string_field) == str)
+ self.assertTrue(unicode_decode_failed or type(string_field) is bytes)
+
+ def testBytesInTextFormat(self):
+ proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff')
+ self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n',
+ unicode(proto))
def testEmptyNestedMessage(self):
proto = unittest_pb2.TestAllTypes()
@@ -1667,16 +1753,19 @@ class ReflectionTest(unittest.TestCase):
self.assertTrue(proto.HasField('optional_nested_message'))
proto = unittest_pb2.TestAllTypes()
- proto.optional_nested_message.MergeFromString('')
+ bytes_read = proto.optional_nested_message.MergeFromString(b'')
+ self.assertEqual(0, bytes_read)
self.assertTrue(proto.HasField('optional_nested_message'))
proto = unittest_pb2.TestAllTypes()
- proto.optional_nested_message.ParseFromString('')
+ proto.optional_nested_message.ParseFromString(b'')
self.assertTrue(proto.HasField('optional_nested_message'))
serialized = proto.SerializeToString()
proto2 = unittest_pb2.TestAllTypes()
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
self.assertTrue(proto2.HasField('optional_nested_message'))
def testSetInParent(self):
@@ -1690,7 +1779,7 @@ class ReflectionTest(unittest.TestCase):
# into separate TestCase classes.
-class TestAllTypesEqualityTest(unittest.TestCase):
+class TestAllTypesEqualityTest(basetest.TestCase):
def setUp(self):
self.first_proto = unittest_pb2.TestAllTypes()
@@ -1706,7 +1795,7 @@ class TestAllTypesEqualityTest(unittest.TestCase):
self.assertEqual(self.first_proto, self.second_proto)
-class FullProtosEqualityTest(unittest.TestCase):
+class FullProtosEqualityTest(basetest.TestCase):
"""Equality tests using completely-full protos as a starting point."""
@@ -1792,7 +1881,7 @@ class FullProtosEqualityTest(unittest.TestCase):
self.assertEqual(self.first_proto, self.second_proto)
-class ExtensionEqualityTest(unittest.TestCase):
+class ExtensionEqualityTest(basetest.TestCase):
def testExtensionEquality(self):
first_proto = unittest_pb2.TestAllExtensions()
@@ -1825,7 +1914,7 @@ class ExtensionEqualityTest(unittest.TestCase):
self.assertEqual(first_proto, second_proto)
-class MutualRecursionEqualityTest(unittest.TestCase):
+class MutualRecursionEqualityTest(basetest.TestCase):
def testEqualityWithMutualRecursion(self):
first_proto = unittest_pb2.TestMutualRecursionA()
@@ -1837,7 +1926,7 @@ class MutualRecursionEqualityTest(unittest.TestCase):
self.assertEqual(first_proto, second_proto)
-class ByteSizeTest(unittest.TestCase):
+class ByteSizeTest(basetest.TestCase):
def setUp(self):
self.proto = unittest_pb2.TestAllTypes()
@@ -2133,14 +2222,16 @@ class ByteSizeTest(unittest.TestCase):
# * Handling of empty submessages (with and without "has"
# bits set).
-class SerializationTest(unittest.TestCase):
+class SerializationTest(basetest.TestCase):
def testSerializeEmtpyMessage(self):
first_proto = unittest_pb2.TestAllTypes()
second_proto = unittest_pb2.TestAllTypes()
serialized = first_proto.SerializeToString()
self.assertEqual(first_proto.ByteSize(), len(serialized))
- second_proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
self.assertEqual(first_proto, second_proto)
def testSerializeAllFields(self):
@@ -2149,7 +2240,9 @@ class SerializationTest(unittest.TestCase):
test_util.SetAllFields(first_proto)
serialized = first_proto.SerializeToString()
self.assertEqual(first_proto.ByteSize(), len(serialized))
- second_proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
self.assertEqual(first_proto, second_proto)
def testSerializeAllExtensions(self):
@@ -2157,7 +2250,19 @@ class SerializationTest(unittest.TestCase):
second_proto = unittest_pb2.TestAllExtensions()
test_util.SetAllExtensions(first_proto)
serialized = first_proto.SerializeToString()
- second_proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
+ self.assertEqual(first_proto, second_proto)
+
+ def testSerializeWithOptionalGroup(self):
+ first_proto = unittest_pb2.TestAllTypes()
+ second_proto = unittest_pb2.TestAllTypes()
+ first_proto.optionalgroup.a = 242
+ serialized = first_proto.SerializeToString()
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
self.assertEqual(first_proto, second_proto)
def testSerializeNegativeValues(self):
@@ -2249,7 +2354,9 @@ class SerializationTest(unittest.TestCase):
second_proto.optional_int32 = 100
second_proto.optional_nested_message.bb = 999
- second_proto.MergeFromString(serialized)
+ bytes_parsed = second_proto.MergeFromString(serialized)
+ self.assertEqual(len(serialized), bytes_parsed)
+
# Ensure that we append to repeated fields.
self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
# Ensure that we overwrite nonrepeatd scalars.
@@ -2274,20 +2381,28 @@ class SerializationTest(unittest.TestCase):
raw = unittest_mset_pb2.RawMessageSet()
self.assertEqual(False,
raw.DESCRIPTOR.GetOptions().message_set_wire_format)
- raw.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ raw.MergeFromString(serialized))
self.assertEqual(2, len(raw.item))
message1 = unittest_mset_pb2.TestMessageSetExtension1()
- message1.MergeFromString(raw.item[0].message)
+ self.assertEqual(
+ len(raw.item[0].message),
+ message1.MergeFromString(raw.item[0].message))
self.assertEqual(123, message1.i)
message2 = unittest_mset_pb2.TestMessageSetExtension2()
- message2.MergeFromString(raw.item[1].message)
+ self.assertEqual(
+ len(raw.item[1].message),
+ message2.MergeFromString(raw.item[1].message))
self.assertEqual('foo', message2.str)
# Deserialize using the MessageSet wire format.
proto2 = unittest_mset_pb2.TestMessageSet()
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
self.assertEqual(123, proto2.Extensions[extension1].i)
self.assertEqual('foo', proto2.Extensions[extension2].str)
@@ -2327,7 +2442,9 @@ class SerializationTest(unittest.TestCase):
# Parse message using the message set wire format.
proto = unittest_mset_pb2.TestMessageSet()
- proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto.MergeFromString(serialized))
# Check that the message parsed well.
extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
@@ -2345,7 +2462,9 @@ class SerializationTest(unittest.TestCase):
proto2 = unittest_pb2.TestEmptyMessage()
# Parsing this message should succeed.
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
# Now test with a int64 field set.
proto = unittest_pb2.TestAllTypes()
@@ -2355,7 +2474,9 @@ class SerializationTest(unittest.TestCase):
# unknown.
proto2 = unittest_pb2.TestEmptyMessage()
# Parsing this message should succeed.
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
def _CheckRaises(self, exc_class, callable_obj, exception):
"""This method checks if the excpetion type and message are as expected."""
@@ -2406,11 +2527,15 @@ class SerializationTest(unittest.TestCase):
partial = proto.SerializePartialToString()
proto2 = unittest_pb2.TestRequired()
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
self.assertEqual(1, proto2.a)
self.assertEqual(2, proto2.b)
self.assertEqual(3, proto2.c)
- proto2.ParseFromString(partial)
+ self.assertEqual(
+ len(partial),
+ proto2.MergeFromString(partial))
self.assertEqual(1, proto2.a)
self.assertEqual(2, proto2.b)
self.assertEqual(3, proto2.c)
@@ -2478,7 +2603,9 @@ class SerializationTest(unittest.TestCase):
second_proto.packed_double.extend([1.0, 2.0])
second_proto.packed_sint32.append(4)
- second_proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
self.assertEqual([3, 1, 2], second_proto.packed_int32)
self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
self.assertEqual([4], second_proto.packed_sint32)
@@ -2511,7 +2638,10 @@ class SerializationTest(unittest.TestCase):
unpacked = unittest_pb2.TestUnpackedTypes()
test_util.SetAllUnpackedFields(unpacked)
packed = unittest_pb2.TestPackedTypes()
- packed.MergeFromString(unpacked.SerializeToString())
+ serialized = unpacked.SerializeToString()
+ self.assertEqual(
+ len(serialized),
+ packed.MergeFromString(serialized))
expected = unittest_pb2.TestPackedTypes()
test_util.SetAllPackedFields(expected)
self.assertEqual(expected, packed)
@@ -2520,7 +2650,10 @@ class SerializationTest(unittest.TestCase):
packed = unittest_pb2.TestPackedTypes()
test_util.SetAllPackedFields(packed)
unpacked = unittest_pb2.TestUnpackedTypes()
- unpacked.MergeFromString(packed.SerializeToString())
+ serialized = packed.SerializeToString()
+ self.assertEqual(
+ len(serialized),
+ unpacked.MergeFromString(serialized))
expected = unittest_pb2.TestUnpackedTypes()
test_util.SetAllUnpackedFields(expected)
self.assertEqual(expected, unpacked)
@@ -2572,7 +2705,7 @@ class SerializationTest(unittest.TestCase):
optional_int32=1,
optional_string='foo',
optional_bool=True,
- optional_bytes='bar',
+ optional_bytes=b'bar',
optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
@@ -2590,7 +2723,7 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(1, proto.optional_int32)
self.assertEqual('foo', proto.optional_string)
self.assertEqual(True, proto.optional_bool)
- self.assertEqual('bar', proto.optional_bytes)
+ self.assertEqual(b'bar', proto.optional_bytes)
self.assertEqual(1, proto.optional_nested_message.bb)
self.assertEqual(1, proto.optional_foreign_message.c)
self.assertEqual(unittest_pb2.TestAllTypes.FOO,
@@ -2640,7 +2773,7 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(3, proto.repeated_int32[2])
-class OptionsTest(unittest.TestCase):
+class OptionsTest(basetest.TestCase):
def testMessageOptions(self):
proto = unittest_mset_pb2.TestMessageSet()
@@ -2667,5 +2800,135 @@ class OptionsTest(unittest.TestCase):
+class ClassAPITest(basetest.TestCase):
+
+ def testMakeClassWithNestedDescriptor(self):
+ leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '',
+ containing_type=None, fields=[],
+ nested_types=[], enum_types=[],
+ extensions=[])
+ child_desc = descriptor.Descriptor('child', 'package.parent.child', '',
+ containing_type=None, fields=[],
+ nested_types=[leaf_desc], enum_types=[],
+ extensions=[])
+ sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling',
+ '', containing_type=None, fields=[],
+ nested_types=[], enum_types=[],
+ extensions=[])
+ parent_desc = descriptor.Descriptor('parent', 'package.parent', '',
+ containing_type=None, fields=[],
+ nested_types=[child_desc, sibling_desc],
+ enum_types=[], extensions=[])
+ message_class = reflection.MakeClass(parent_desc)
+ self.assertIn('child', message_class.__dict__)
+ self.assertIn('sibling', message_class.__dict__)
+ self.assertIn('leaf', message_class.child.__dict__)
+
+ def _GetSerializedFileDescriptor(self, name):
+ """Get a serialized representation of a test FileDescriptorProto.
+
+ Args:
+ name: All calls to this must use a unique message name, to avoid
+ collisions in the cpp descriptor pool.
+ Returns:
+ A string containing the serialized form of a test FileDescriptorProto.
+ """
+ file_descriptor_str = (
+ 'message_type {'
+ ' name: "' + name + '"'
+ ' field {'
+ ' name: "flat"'
+ ' number: 1'
+ ' label: LABEL_REPEATED'
+ ' type: TYPE_UINT32'
+ ' }'
+ ' field {'
+ ' name: "bar"'
+ ' number: 2'
+ ' label: LABEL_OPTIONAL'
+ ' type: TYPE_MESSAGE'
+ ' type_name: "Bar"'
+ ' }'
+ ' nested_type {'
+ ' name: "Bar"'
+ ' field {'
+ ' name: "baz"'
+ ' number: 3'
+ ' label: LABEL_OPTIONAL'
+ ' type: TYPE_MESSAGE'
+ ' type_name: "Baz"'
+ ' }'
+ ' nested_type {'
+ ' name: "Baz"'
+ ' enum_type {'
+ ' name: "deep_enum"'
+ ' value {'
+ ' name: "VALUE_A"'
+ ' number: 0'
+ ' }'
+ ' }'
+ ' field {'
+ ' name: "deep"'
+ ' number: 4'
+ ' label: LABEL_OPTIONAL'
+ ' type: TYPE_UINT32'
+ ' }'
+ ' }'
+ ' }'
+ '}')
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ text_format.Merge(file_descriptor_str, file_descriptor)
+ return file_descriptor.SerializeToString()
+
+ def testParsingFlatClassWithExplicitClassDeclaration(self):
+ """Test that the generated class can parse a flat message."""
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A'))
+ msg_descriptor = descriptor.MakeDescriptor(
+ file_descriptor.message_type[0])
+
+ class MessageClass(message.Message):
+ __metaclass__ = reflection.GeneratedProtocolMessageType
+ DESCRIPTOR = msg_descriptor
+ msg = MessageClass()
+ msg_str = (
+ 'flat: 0 '
+ 'flat: 1 '
+ 'flat: 2 ')
+ text_format.Merge(msg_str, msg)
+ self.assertEqual(msg.flat, [0, 1, 2])
+
+ def testParsingFlatClass(self):
+ """Test that the generated class can parse a flat message."""
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B'))
+ msg_descriptor = descriptor.MakeDescriptor(
+ file_descriptor.message_type[0])
+ msg_class = reflection.MakeClass(msg_descriptor)
+ msg = msg_class()
+ msg_str = (
+ 'flat: 0 '
+ 'flat: 1 '
+ 'flat: 2 ')
+ text_format.Merge(msg_str, msg)
+ self.assertEqual(msg.flat, [0, 1, 2])
+
+ def testParsingNestedClass(self):
+ """Test that the generated class can parse a nested message."""
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
+ msg_descriptor = descriptor.MakeDescriptor(
+ file_descriptor.message_type[0])
+ msg_class = reflection.MakeClass(msg_descriptor)
+ msg = msg_class()
+ msg_str = (
+ 'bar {'
+ ' baz {'
+ ' deep: 4'
+ ' }'
+ '}')
+ text_format.Merge(msg_str, msg)
+ self.assertEqual(msg.bar.baz.deep, 4)
+
if __name__ == '__main__':
- unittest.main()
+ basetest.main()