aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/internal
diff options
context:
space:
mode:
authorGravatar Jisi Liu <jisi.liu@gmail.com>2015-02-25 16:39:11 -0800
committerGravatar Jisi Liu <jisi.liu@gmail.com>2015-02-25 16:39:11 -0800
commitada65567852b96fdb4d070c0c3f86ca7b77824f9 (patch)
treea506994ce921ace3e6f88ca130a17af7f85c3d0f /python/google/protobuf/internal
parent581be24606a925d038f382dc4c86256e2d29e001 (diff)
Down integrate from Google internal.
Change-Id: I34d301133eea9c6f3a822c47d1f91e136fd33145
Diffstat (limited to 'python/google/protobuf/internal')
-rw-r--r--python/google/protobuf/internal/api_implementation.cc14
-rwxr-xr-xpython/google/protobuf/internal/api_implementation.py38
-rw-r--r--python/google/protobuf/internal/api_implementation_default_test.py63
-rwxr-xr-xpython/google/protobuf/internal/containers.py37
-rwxr-xr-xpython/google/protobuf/internal/decoder.py3
-rwxr-xr-xpython/google/protobuf/internal/descriptor_test.py176
-rwxr-xr-xpython/google/protobuf/internal/message_test.py782
-rwxr-xr-xpython/google/protobuf/internal/python_message.py81
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py21
-rwxr-xr-xpython/google/protobuf/internal/test_util.py170
-rwxr-xr-xpython/google/protobuf/internal/text_format_test.py557
-rwxr-xr-xpython/google/protobuf/internal/type_checkers.py8
-rwxr-xr-xpython/google/protobuf/internal/unknown_fields_test.py120
13 files changed, 1338 insertions, 732 deletions
diff --git a/python/google/protobuf/internal/api_implementation.cc b/python/google/protobuf/internal/api_implementation.cc
index 83db40b1..6db12e8d 100644
--- a/python/google/protobuf/internal/api_implementation.cc
+++ b/python/google/protobuf/internal/api_implementation.cc
@@ -50,10 +50,7 @@ namespace python {
// and
// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2
#ifdef PYTHON_PROTO2_CPP_IMPL_V1
-#if PY_MAJOR_VERSION >= 3
-#error "PYTHON_PROTO2_CPP_IMPL_V1 is not supported under Python 3."
-#endif
-static int kImplVersion = 1;
+#error "PYTHON_PROTO2_CPP_IMPL_V1 is no longer supported."
#else
#ifdef PYTHON_PROTO2_CPP_IMPL_V2
static int kImplVersion = 2;
@@ -62,14 +59,7 @@ static int kImplVersion = 2;
static int kImplVersion = 0;
#else
-// The defaults are set here. Python 3 uses the fast C++ APIv2 by default.
-// Python 2 still uses the Python version by default until some compatibility
-// issues can be worked around.
-#if PY_MAJOR_VERSION >= 3
-static int kImplVersion = 2;
-#else
-static int kImplVersion = 0;
-#endif
+static int kImplVersion = -1; // -1 means "Unspecified by compiler flags".
#endif // PYTHON_PROTO2_PYTHON_IMPL
#endif // PYTHON_PROTO2_CPP_IMPL_V2
diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py
index f7926c16..8ba4357c 100755
--- a/python/google/protobuf/internal/api_implementation.py
+++ b/python/google/protobuf/internal/api_implementation.py
@@ -40,14 +40,33 @@ try:
# The compile-time constants in the _api_implementation module can be used to
# switch to a certain implementation of the Python API at build time.
_api_version = _api_implementation.api_version
- del _api_implementation
+ _proto_extension_modules_exist_in_build = True
except ImportError:
- _api_version = 0
+ _api_version = -1 # Unspecified by compiler flags.
+ _proto_extension_modules_exist_in_build = False
+
+if _api_version == 1:
+ raise ValueError('api_version=1 is no longer supported.')
+if _api_version < 0: # Still unspecified?
+ try:
+ # The presence of this module in a build allows the proto implementation to
+ # be upgraded merely via build deps rather than a compiler flag or the
+ # runtime environment variable.
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf import _use_fast_cpp_protos
+ # Work around a known issue in the classic bootstrap .par import hook.
+ if not _use_fast_cpp_protos:
+ raise ImportError('_use_fast_cpp_protos import succeeded but was None')
+ del _use_fast_cpp_protos
+ _api_version = 2
+ except ImportError:
+ if _proto_extension_modules_exist_in_build:
+ if sys.version_info[0] >= 3: # Python 3 defaults to C++ impl v2.
+ _api_version = 2
+ # TODO(b/17427486): Make Python 2 default to C++ impl v2.
_default_implementation_type = (
- 'python' if _api_version == 0 else 'cpp')
-_default_version_str = (
- '1' if _api_version <= 1 else '2')
+ 'python' if _api_version <= 0 else 'cpp')
# This environment variable can be used to switch to a certain implementation
# of the Python API, overriding the compile-time constants in the
@@ -64,13 +83,12 @@ if _implementation_type != 'python':
# _api_implementation module. Right now only 1 and 2 are valid values. Any other
# value will be ignored.
_implementation_version_str = os.getenv(
- 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION',
- _default_version_str)
+ 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', '2')
-if _implementation_version_str not in ('1', '2'):
+if _implementation_version_str != '2':
raise ValueError(
- "unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: '" +
- _implementation_version_str + "' (supported versions: 1, 2)"
+ 'unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: "' +
+ _implementation_version_str + '" (supported versions: 2)'
)
_implementation_version = int(_implementation_version_str)
diff --git a/python/google/protobuf/internal/api_implementation_default_test.py b/python/google/protobuf/internal/api_implementation_default_test.py
deleted file mode 100644
index 78d5cf23..00000000
--- a/python/google/protobuf/internal/api_implementation_default_test.py
+++ /dev/null
@@ -1,63 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Test that the api_implementation defaults are what we expect."""
-
-import os
-import sys
-# Clear environment implementation settings before the google3 imports.
-os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', None)
-os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', None)
-
-# pylint: disable=g-import-not-at-top
-from google.apputils import basetest
-from google.protobuf.internal import api_implementation
-
-
-class ApiImplementationDefaultTest(basetest.TestCase):
-
- if sys.version_info.major <= 2:
-
- def testThatPythonIsTheDefault(self):
- """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail."""
- self.assertEqual('python', api_implementation.Type())
-
- else:
-
- def testThatCppApiV2IsTheDefault(self):
- """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail."""
- self.assertEqual('cpp', api_implementation.Type())
- self.assertEqual(2, api_implementation.Version())
-
-
-if __name__ == '__main__':
- basetest.main()
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index 20bfa857..d976f9e1 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -41,7 +41,6 @@ are:
__author__ = 'petar@google.com (Petar Petrov)'
-
class BaseContainer(object):
"""Base container class."""
@@ -119,15 +118,23 @@ class RepeatedScalarFieldContainer(BaseContainer):
self._message_listener.Modified()
def extend(self, elem_seq):
- """Extends by appending the given sequence. Similar to list.extend()."""
- if not elem_seq:
- return
+ """Extends by appending the given iterable. Similar to list.extend()."""
- new_values = []
- for elem in elem_seq:
- new_values.append(self._type_checker.CheckValue(elem))
- self._values.extend(new_values)
- self._message_listener.Modified()
+ if elem_seq is None:
+ return
+ try:
+ elem_seq_iter = iter(elem_seq)
+ except TypeError:
+ if not elem_seq:
+ # silently ignore falsy inputs :-/.
+ # TODO(ptucker): Deprecate this behavior. b/18413862
+ return
+ raise
+
+ new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter]
+ if new_values:
+ self._values.extend(new_values)
+ self._message_listener.Modified()
def MergeFrom(self, other):
"""Appends the contents of another repeated field of the same type to this
@@ -141,6 +148,12 @@ class RepeatedScalarFieldContainer(BaseContainer):
self._values.remove(elem)
self._message_listener.Modified()
+ def pop(self, key=-1):
+ """Removes and returns an item at a given index. Similar to list.pop()."""
+ value = self._values[key]
+ self.__delitem__(key)
+ return value
+
def __setitem__(self, key, value):
"""Sets the item on the specified position."""
if isinstance(key, slice): # PY3
@@ -245,6 +258,12 @@ class RepeatedCompositeFieldContainer(BaseContainer):
self._values.remove(elem)
self._message_listener.Modified()
+ def pop(self, key=-1):
+ """Removes and returns an item at a given index. Similar to list.pop()."""
+ value = self._values[key]
+ self.__delitem__(key)
+ return value
+
def __getslice__(self, start, stop):
"""Retrieves the subset of items from between the specified indices."""
return self._values[start:stop]
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index a4b90608..0f500606 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -621,9 +621,6 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
if value is None:
value = field_dict.setdefault(key, new_default(message))
while 1:
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
# Read length.
(size, pos) = local_DecodeVarint(buffer, pos)
new_pos = pos + size
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index 3924f21a..50c4dbba 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -34,12 +34,16 @@
__author__ = 'robinson@google.com (Will Robinson)'
+import sys
+
from google.apputils import basetest
from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
+from google.protobuf.internal import api_implementation
from google.protobuf import descriptor
+from google.protobuf import symbol_database
from google.protobuf import text_format
@@ -51,41 +55,28 @@ name: 'TestEmptyMessage'
class DescriptorTest(basetest.TestCase):
def setUp(self):
- self.my_file = descriptor.FileDescriptor(
+ file_proto = descriptor_pb2.FileDescriptorProto(
name='some/filename/some.proto',
- package='protobuf_unittest'
- )
- self.my_enum = descriptor.EnumDescriptor(
- name='ForeignEnum',
- full_name='protobuf_unittest.ForeignEnum',
- filename=None,
- file=self.my_file,
- values=[
- descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4),
- descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5),
- descriptor.EnumValueDescriptor(name='FOREIGN_BAZ', index=2, number=6),
- ])
- self.my_message = descriptor.Descriptor(
- name='NestedMessage',
- full_name='protobuf_unittest.TestAllTypes.NestedMessage',
- filename=None,
- file=self.my_file,
- containing_type=None,
- fields=[
- descriptor.FieldDescriptor(
- name='bb',
- full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb',
- index=0, number=1,
- type=5, cpp_type=1, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None),
- ],
- nested_types=[],
- enum_types=[
- self.my_enum,
- ],
- extensions=[])
+ package='protobuf_unittest')
+ message_proto = file_proto.message_type.add(
+ name='NestedMessage')
+ message_proto.field.add(
+ name='bb',
+ number=1,
+ type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32,
+ label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL)
+ enum_proto = message_proto.enum_type.add(
+ name='ForeignEnum')
+ enum_proto.value.add(name='FOREIGN_FOO', number=4)
+ enum_proto.value.add(name='FOREIGN_BAR', number=5)
+ enum_proto.value.add(name='FOREIGN_BAZ', number=6)
+
+ descriptor_pool = symbol_database.Default().pool
+ descriptor_pool.Add(file_proto)
+ self.my_file = descriptor_pool.FindFileByName(file_proto.name)
+ self.my_message = self.my_file.message_types_by_name[message_proto.name]
+ self.my_enum = self.my_message.enum_types_by_name[enum_proto.name]
+
self.my_method = descriptor.MethodDescriptor(
name='Bar',
full_name='protobuf_unittest.TestService.Bar',
@@ -173,6 +164,11 @@ class DescriptorTest(basetest.TestCase):
self.assertEqual(unittest_custom_options_pb2.METHODOPT1_VAL2,
method_options.Extensions[method_opt1])
+ message_descriptor = (
+ unittest_custom_options_pb2.DummyMessageContainingEnum.DESCRIPTOR)
+ self.assertTrue(file_descriptor.has_options)
+ self.assertFalse(message_descriptor.has_options)
+
def testDifferentCustomOptionTypes(self):
kint32min = -2**31
kint64min = -2**63
@@ -394,6 +390,108 @@ class DescriptorTest(basetest.TestCase):
self.assertEqual(self.my_file.name, 'some/filename/some.proto')
self.assertEqual(self.my_file.package, 'protobuf_unittest')
+ @basetest.unittest.skipIf(
+ api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
+ 'Immutability of descriptors is only enforced in v2 implementation')
+ def testImmutableCppDescriptor(self):
+ message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ with self.assertRaises(AttributeError):
+ message_descriptor.fields_by_name = None
+ with self.assertRaises(TypeError):
+ message_descriptor.fields_by_name['Another'] = None
+ with self.assertRaises(TypeError):
+ message_descriptor.fields.append(None)
+
+
+class GeneratedDescriptorTest(basetest.TestCase):
+ """Tests for the properties of descriptors in generated code."""
+
+ def CheckMessageDescriptor(self, message_descriptor):
+ # Basic properties
+ self.assertEqual(message_descriptor.name, 'TestAllTypes')
+ self.assertEqual(message_descriptor.full_name,
+ 'protobuf_unittest.TestAllTypes')
+ # Test equality and hashability
+ self.assertEqual(message_descriptor, message_descriptor)
+ self.assertEqual(message_descriptor.fields[0].containing_type,
+ message_descriptor)
+ self.assertIn(message_descriptor, [message_descriptor])
+ self.assertIn(message_descriptor, {message_descriptor: None})
+ # Test field containers
+ self.CheckDescriptorSequence(message_descriptor.fields)
+ self.CheckDescriptorMapping(message_descriptor.fields_by_name)
+ self.CheckDescriptorMapping(message_descriptor.fields_by_number)
+
+ def CheckFieldDescriptor(self, field_descriptor):
+ # Basic properties
+ self.assertEqual(field_descriptor.name, 'optional_int32')
+ self.assertEqual(field_descriptor.full_name,
+ 'protobuf_unittest.TestAllTypes.optional_int32')
+ self.assertEqual(field_descriptor.containing_type.name, 'TestAllTypes')
+ # Test equality and hashability
+ self.assertEqual(field_descriptor, field_descriptor)
+ self.assertEqual(
+ field_descriptor.containing_type.fields_by_name['optional_int32'],
+ field_descriptor)
+ self.assertIn(field_descriptor, [field_descriptor])
+ self.assertIn(field_descriptor, {field_descriptor: None})
+
+ def CheckDescriptorSequence(self, sequence):
+ # Verifies that a property like 'messageDescriptor.fields' has all the
+ # properties of an immutable abc.Sequence.
+ self.assertGreater(len(sequence), 0) # Sized
+ self.assertEqual(len(sequence), len(list(sequence))) # Iterable
+ item = sequence[0]
+ self.assertEqual(item, sequence[0])
+ self.assertIn(item, sequence) # Container
+ self.assertEqual(sequence.index(item), 0)
+ self.assertEqual(sequence.count(item), 1)
+ reversed_iterator = reversed(sequence)
+ self.assertEqual(list(reversed_iterator), list(sequence)[::-1])
+ self.assertRaises(StopIteration, next, reversed_iterator)
+
+ def CheckDescriptorMapping(self, mapping):
+ # Verifies that a property like 'messageDescriptor.fields' has all the
+ # properties of an immutable abc.Mapping.
+ self.assertGreater(len(mapping), 0) # Sized
+ self.assertEqual(len(mapping), len(list(mapping))) # Iterable
+ if sys.version_info.major >= 3:
+ key, item = next(iter(mapping.items()))
+ else:
+ key, item = mapping.items()[0]
+ self.assertIn(key, mapping) # Container
+ self.assertEqual(mapping.get(key), item)
+ # keys(), iterkeys() &co
+ item = (next(iter(mapping.keys())), next(iter(mapping.values())))
+ self.assertEqual(item, next(iter(mapping.items())))
+ if sys.version_info.major < 3:
+ def CheckItems(seq, iterator):
+ self.assertEqual(next(iterator), seq[0])
+ self.assertEqual(list(iterator), seq[1:])
+ CheckItems(mapping.keys(), mapping.iterkeys())
+ CheckItems(mapping.values(), mapping.itervalues())
+ CheckItems(mapping.items(), mapping.iteritems())
+
+ def testDescriptor(self):
+ message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ self.CheckMessageDescriptor(message_descriptor)
+ field_descriptor = message_descriptor.fields_by_name['optional_int32']
+ self.CheckFieldDescriptor(field_descriptor)
+
+ def testCppDescriptorContainer(self):
+ # Check that the collection is still valid even if the parent disappeared.
+ enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum']
+ values = enum.values
+ del enum
+ self.assertEqual('FOO', values[0].name)
+
+ def testCppDescriptorContainer_Iterator(self):
+ # Same test with the iterator
+ enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum']
+ values_iter = iter(enum.values)
+ del enum
+ self.assertEqual('FOO', next(values_iter).name)
+
class DescriptorCopyToProtoTest(basetest.TestCase):
"""Tests for CopyTo functions of Descriptor."""
@@ -588,10 +686,12 @@ class DescriptorCopyToProtoTest(basetest.TestCase):
output_type: '.protobuf_unittest.BarResponse'
>
"""
- self._InternalTestCopyToProto(
- unittest_pb2.TestService.DESCRIPTOR,
- descriptor_pb2.ServiceDescriptorProto,
- TEST_SERVICE_ASCII)
+ # TODO(rocking): enable this test after the proto descriptor change is
+ # checked in.
+ #self._InternalTestCopyToProto(
+ # unittest_pb2.TestService.DESCRIPTOR,
+ # descriptor_pb2.ServiceDescriptorProto,
+ # TEST_SERVICE_ASCII)
class MakeDescriptorTest(basetest.TestCase):
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 42e2ad7e..7ab814cf 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -48,9 +48,12 @@ import math
import operator
import pickle
import sys
+import unittest
from google.apputils import basetest
+from google.apputils.pybase import parameterized
from google.protobuf import unittest_pb2
+from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import test_util
from google.protobuf import message
@@ -69,88 +72,65 @@ def IsNegInf(val):
return isinf(val) and (val < 0)
+@parameterized.Parameters(
+ (unittest_pb2),
+ (unittest_proto3_arena_pb2))
class MessageTest(basetest.TestCase):
- def testBadUtf8String(self):
+ def testBadUtf8String(self, message_module):
if api_implementation.Type() != 'python':
self.skipTest("Skipping testBadUtf8String, currently only the python "
"api implementation raises UnicodeDecodeError when a "
"string field contains bad utf-8.")
bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
with self.assertRaises(UnicodeDecodeError) as context:
- unittest_pb2.TestAllTypes.FromString(bad_utf8_data)
- self.assertIn('field: protobuf_unittest.TestAllTypes.optional_string',
- str(context.exception))
-
- def testGoldenMessage(self):
- golden_data = test_util.GoldenFileData(
- 'golden_message_oneof_implemented')
- golden_message = unittest_pb2.TestAllTypes()
- golden_message.ParseFromString(golden_data)
- test_util.ExpectAllFieldsSet(self, golden_message)
- self.assertEqual(golden_data, golden_message.SerializeToString())
- golden_copy = copy.deepcopy(golden_message)
- self.assertEqual(golden_data, golden_copy.SerializeToString())
+ message_module.TestAllTypes.FromString(bad_utf8_data)
+ self.assertIn('TestAllTypes.optional_string', str(context.exception))
+
+ def testGoldenMessage(self, message_module):
+ # Proto3 doesn't have the "default_foo" members or foreign enums,
+ # and doesn't preserve unknown fields, so for proto3 we use a golden
+ # message that doesn't have these fields set.
+ if message_module is unittest_pb2:
+ golden_data = test_util.GoldenFileData(
+ 'golden_message_oneof_implemented')
+ else:
+ golden_data = test_util.GoldenFileData('golden_message_proto3')
- def testGoldenExtensions(self):
- golden_data = test_util.GoldenFileData('golden_message')
- golden_message = unittest_pb2.TestAllExtensions()
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
- all_set = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(all_set)
- self.assertEquals(all_set, golden_message)
+ if message_module is unittest_pb2:
+ test_util.ExpectAllFieldsSet(self, golden_message)
self.assertEqual(golden_data, golden_message.SerializeToString())
golden_copy = copy.deepcopy(golden_message)
self.assertEqual(golden_data, golden_copy.SerializeToString())
- def testGoldenPackedMessage(self):
+ def testGoldenPackedMessage(self, message_module):
golden_data = test_util.GoldenFileData('golden_packed_fields_message')
- golden_message = unittest_pb2.TestPackedTypes()
+ golden_message = message_module.TestPackedTypes()
golden_message.ParseFromString(golden_data)
- all_set = unittest_pb2.TestPackedTypes()
+ all_set = message_module.TestPackedTypes()
test_util.SetAllPackedFields(all_set)
- self.assertEquals(all_set, golden_message)
+ self.assertEqual(all_set, golden_message)
self.assertEqual(golden_data, all_set.SerializeToString())
golden_copy = copy.deepcopy(golden_message)
self.assertEqual(golden_data, golden_copy.SerializeToString())
- def testGoldenPackedExtensions(self):
- golden_data = test_util.GoldenFileData('golden_packed_fields_message')
- golden_message = unittest_pb2.TestPackedExtensions()
- golden_message.ParseFromString(golden_data)
- all_set = unittest_pb2.TestPackedExtensions()
- test_util.SetAllPackedExtensions(all_set)
- self.assertEquals(all_set, golden_message)
- self.assertEqual(golden_data, all_set.SerializeToString())
- golden_copy = copy.deepcopy(golden_message)
- self.assertEqual(golden_data, golden_copy.SerializeToString())
-
- def testPickleSupport(self):
+ def testPickleSupport(self, message_module):
golden_data = test_util.GoldenFileData('golden_message')
- golden_message = unittest_pb2.TestAllTypes()
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
pickled_message = pickle.dumps(golden_message)
unpickled_message = pickle.loads(pickled_message)
- self.assertEquals(unpickled_message, golden_message)
-
-
- def testPickleIncompleteProto(self):
- golden_message = unittest_pb2.TestRequired(a=1)
- pickled_message = pickle.dumps(golden_message)
-
- unpickled_message = pickle.loads(pickled_message)
- self.assertEquals(unpickled_message, golden_message)
- self.assertEquals(unpickled_message.a, 1)
- # This is still an incomplete proto - so serializing should fail
- self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
+ self.assertEqual(unpickled_message, golden_message)
- def testPositiveInfinity(self):
+ def testPositiveInfinity(self, message_module):
golden_data = (b'\x5D\x00\x00\x80\x7F'
b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
b'\xCD\x02\x00\x00\x80\x7F'
b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
- golden_message = unittest_pb2.TestAllTypes()
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsPosInf(golden_message.optional_float))
self.assertTrue(IsPosInf(golden_message.optional_double))
@@ -158,12 +138,12 @@ class MessageTest(basetest.TestCase):
self.assertTrue(IsPosInf(golden_message.repeated_double[0]))
self.assertEqual(golden_data, golden_message.SerializeToString())
- def testNegativeInfinity(self):
+ def testNegativeInfinity(self, message_module):
golden_data = (b'\x5D\x00\x00\x80\xFF'
b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
b'\xCD\x02\x00\x00\x80\xFF'
b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
- golden_message = unittest_pb2.TestAllTypes()
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsNegInf(golden_message.optional_float))
self.assertTrue(IsNegInf(golden_message.optional_double))
@@ -171,12 +151,12 @@ class MessageTest(basetest.TestCase):
self.assertTrue(IsNegInf(golden_message.repeated_double[0]))
self.assertEqual(golden_data, golden_message.SerializeToString())
- def testNotANumber(self):
+ def testNotANumber(self, message_module):
golden_data = (b'\x5D\x00\x00\xC0\x7F'
b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
b'\xCD\x02\x00\x00\xC0\x7F'
b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
- golden_message = unittest_pb2.TestAllTypes()
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(isnan(golden_message.optional_float))
self.assertTrue(isnan(golden_message.optional_double))
@@ -188,47 +168,47 @@ class MessageTest(basetest.TestCase):
# verify the serialized string can be converted into a correctly
# behaving protocol buffer.
serialized = golden_message.SerializeToString()
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.ParseFromString(serialized)
self.assertTrue(isnan(message.optional_float))
self.assertTrue(isnan(message.optional_double))
self.assertTrue(isnan(message.repeated_float[0]))
self.assertTrue(isnan(message.repeated_double[0]))
- def testPositiveInfinityPacked(self):
+ def testPositiveInfinityPacked(self, message_module):
golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
- golden_message = unittest_pb2.TestPackedTypes()
+ golden_message = message_module.TestPackedTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsPosInf(golden_message.packed_float[0]))
self.assertTrue(IsPosInf(golden_message.packed_double[0]))
self.assertEqual(golden_data, golden_message.SerializeToString())
- def testNegativeInfinityPacked(self):
+ def testNegativeInfinityPacked(self, message_module):
golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
- golden_message = unittest_pb2.TestPackedTypes()
+ golden_message = message_module.TestPackedTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsNegInf(golden_message.packed_float[0]))
self.assertTrue(IsNegInf(golden_message.packed_double[0]))
self.assertEqual(golden_data, golden_message.SerializeToString())
- def testNotANumberPacked(self):
+ def testNotANumberPacked(self, message_module):
golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
- golden_message = unittest_pb2.TestPackedTypes()
+ golden_message = message_module.TestPackedTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(isnan(golden_message.packed_float[0]))
self.assertTrue(isnan(golden_message.packed_double[0]))
serialized = golden_message.SerializeToString()
- message = unittest_pb2.TestPackedTypes()
+ message = message_module.TestPackedTypes()
message.ParseFromString(serialized)
self.assertTrue(isnan(message.packed_float[0]))
self.assertTrue(isnan(message.packed_double[0]))
- def testExtremeFloatValues(self):
- message = unittest_pb2.TestAllTypes()
+ def testExtremeFloatValues(self, message_module):
+ message = message_module.TestAllTypes()
# Most positive exponent, no significand bits set.
kMostPosExponentNoSigBits = math.pow(2, 127)
@@ -272,8 +252,8 @@ class MessageTest(basetest.TestCase):
message.ParseFromString(message.SerializeToString())
self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit)
- def testExtremeDoubleValues(self):
- message = unittest_pb2.TestAllTypes()
+ def testExtremeDoubleValues(self, message_module):
+ message = message_module.TestAllTypes()
# Most positive exponent, no significand bits set.
kMostPosExponentNoSigBits = math.pow(2, 1023)
@@ -317,43 +297,43 @@ class MessageTest(basetest.TestCase):
message.ParseFromString(message.SerializeToString())
self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit)
- def testFloatPrinting(self):
- message = unittest_pb2.TestAllTypes()
+ def testFloatPrinting(self, message_module):
+ message = message_module.TestAllTypes()
message.optional_float = 2.0
self.assertEqual(str(message), 'optional_float: 2.0\n')
- def testHighPrecisionFloatPrinting(self):
- message = unittest_pb2.TestAllTypes()
+ def testHighPrecisionFloatPrinting(self, message_module):
+ message = message_module.TestAllTypes()
message.optional_double = 0.12345678912345678
if sys.version_info.major >= 3:
self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n')
else:
self.assertEqual(str(message), 'optional_double: 0.123456789123\n')
- def testUnknownFieldPrinting(self):
- populated = unittest_pb2.TestAllTypes()
+ def testUnknownFieldPrinting(self, message_module):
+ populated = message_module.TestAllTypes()
test_util.SetAllNonLazyFields(populated)
- empty = unittest_pb2.TestEmptyMessage()
+ empty = message_module.TestEmptyMessage()
empty.ParseFromString(populated.SerializeToString())
self.assertEqual(str(empty), '')
- def testRepeatedNestedFieldIteration(self):
- msg = unittest_pb2.TestAllTypes()
+ def testRepeatedNestedFieldIteration(self, message_module):
+ msg = message_module.TestAllTypes()
msg.repeated_nested_message.add(bb=1)
msg.repeated_nested_message.add(bb=2)
msg.repeated_nested_message.add(bb=3)
msg.repeated_nested_message.add(bb=4)
- self.assertEquals([1, 2, 3, 4],
- [m.bb for m in msg.repeated_nested_message])
- self.assertEquals([4, 3, 2, 1],
- [m.bb for m in reversed(msg.repeated_nested_message)])
- self.assertEquals([4, 3, 2, 1],
- [m.bb for m in msg.repeated_nested_message[::-1]])
+ self.assertEqual([1, 2, 3, 4],
+ [m.bb for m in msg.repeated_nested_message])
+ self.assertEqual([4, 3, 2, 1],
+ [m.bb for m in reversed(msg.repeated_nested_message)])
+ self.assertEqual([4, 3, 2, 1],
+ [m.bb for m in msg.repeated_nested_message[::-1]])
- def testSortingRepeatedScalarFieldsDefaultComparator(self):
+ def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
"""Check some different types with the default comparator."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
# TODO(mattp): would testing more scalar types strengthen test?
message.repeated_int32.append(1)
@@ -388,9 +368,9 @@ class MessageTest(basetest.TestCase):
self.assertEqual(message.repeated_bytes[1], b'b')
self.assertEqual(message.repeated_bytes[2], b'c')
- def testSortingRepeatedScalarFieldsCustomComparator(self):
+ def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
"""Check some different types with custom comparator."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.repeated_int32.append(-3)
message.repeated_int32.append(-2)
@@ -408,9 +388,9 @@ class MessageTest(basetest.TestCase):
self.assertEqual(message.repeated_string[1], 'bb')
self.assertEqual(message.repeated_string[2], 'aaa')
- def testSortingRepeatedCompositeFieldsCustomComparator(self):
+ def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module):
"""Check passing a custom comparator to sort a repeated composite field."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.repeated_nested_message.add().bb = 1
message.repeated_nested_message.add().bb = 3
@@ -426,9 +406,9 @@ class MessageTest(basetest.TestCase):
self.assertEqual(message.repeated_nested_message[4].bb, 5)
self.assertEqual(message.repeated_nested_message[5].bb, 6)
- def testRepeatedCompositeFieldSortArguments(self):
+ def testRepeatedCompositeFieldSortArguments(self, message_module):
"""Check sorting a repeated composite field using list.sort() arguments."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
get_bb = operator.attrgetter('bb')
cmp_bb = lambda a, b: cmp(a.bb, b.bb)
@@ -452,9 +432,9 @@ class MessageTest(basetest.TestCase):
self.assertEqual([k.bb for k in message.repeated_nested_message],
[6, 5, 4, 3, 2, 1])
- def testRepeatedScalarFieldSortArguments(self):
+ def testRepeatedScalarFieldSortArguments(self, message_module):
"""Check sorting a scalar field using list.sort() arguments."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.repeated_int32.append(-3)
message.repeated_int32.append(-2)
@@ -484,9 +464,9 @@ class MessageTest(basetest.TestCase):
message.repeated_string.sort(cmp=len_cmp, reverse=True)
self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
- def testRepeatedFieldsComparable(self):
- m1 = unittest_pb2.TestAllTypes()
- m2 = unittest_pb2.TestAllTypes()
+ def testRepeatedFieldsComparable(self, message_module):
+ m1 = message_module.TestAllTypes()
+ m2 = message_module.TestAllTypes()
m1.repeated_int32.append(0)
m1.repeated_int32.append(1)
m1.repeated_int32.append(2)
@@ -519,55 +499,6 @@ class MessageTest(basetest.TestCase):
# TODO(anuraag): Implement extensiondict comparison in C++ and then add test
- def testParsingMerge(self):
- """Check the merge behavior when a required or optional field appears
- multiple times in the input."""
- messages = [
- unittest_pb2.TestAllTypes(),
- unittest_pb2.TestAllTypes(),
- unittest_pb2.TestAllTypes() ]
- messages[0].optional_int32 = 1
- messages[1].optional_int64 = 2
- messages[2].optional_int32 = 3
- messages[2].optional_string = 'hello'
-
- merged_message = unittest_pb2.TestAllTypes()
- merged_message.optional_int32 = 3
- merged_message.optional_int64 = 2
- merged_message.optional_string = 'hello'
-
- generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
- generator.field1.extend(messages)
- generator.field2.extend(messages)
- generator.field3.extend(messages)
- generator.ext1.extend(messages)
- generator.ext2.extend(messages)
- generator.group1.add().field1.MergeFrom(messages[0])
- generator.group1.add().field1.MergeFrom(messages[1])
- generator.group1.add().field1.MergeFrom(messages[2])
- generator.group2.add().field1.MergeFrom(messages[0])
- generator.group2.add().field1.MergeFrom(messages[1])
- generator.group2.add().field1.MergeFrom(messages[2])
-
- data = generator.SerializeToString()
- parsing_merge = unittest_pb2.TestParsingMerge()
- parsing_merge.ParseFromString(data)
-
- # Required and optional fields should be merged.
- self.assertEqual(parsing_merge.required_all_types, merged_message)
- self.assertEqual(parsing_merge.optional_all_types, merged_message)
- self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
- merged_message)
- self.assertEqual(parsing_merge.Extensions[
- unittest_pb2.TestParsingMerge.optional_ext],
- merged_message)
-
- # Repeated fields should not be merged.
- self.assertEqual(len(parsing_merge.repeated_all_types), 3)
- self.assertEqual(len(parsing_merge.repeatedgroup), 3)
- self.assertEqual(len(parsing_merge.Extensions[
- unittest_pb2.TestParsingMerge.repeated_ext]), 3)
-
def ensureNestedMessageExists(self, msg, attribute):
"""Make sure that a nested message object exists.
@@ -577,12 +508,28 @@ class MessageTest(basetest.TestCase):
getattr(msg, attribute)
self.assertFalse(msg.HasField(attribute))
- def testOneofGetCaseNonexistingField(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofGetCaseNonexistingField(self, message_module):
+ m = message_module.TestAllTypes()
self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
- def testOneofSemantics(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofDefaultValues(self, message_module):
+ m = message_module.TestAllTypes()
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+ self.assertFalse(m.HasField('oneof_uint32'))
+
+ # Oneof is set even when setting it to a default value.
+ m.oneof_uint32 = 0
+ self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
+ self.assertTrue(m.HasField('oneof_uint32'))
+ self.assertFalse(m.HasField('oneof_string'))
+
+ m.oneof_string = ""
+ self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
+ self.assertTrue(m.HasField('oneof_string'))
+ self.assertFalse(m.HasField('oneof_uint32'))
+
+ def testOneofSemantics(self, message_module):
+ m = message_module.TestAllTypes()
self.assertIs(None, m.WhichOneof('oneof_field'))
m.oneof_uint32 = 11
@@ -604,96 +551,569 @@ class MessageTest(basetest.TestCase):
self.assertFalse(m.HasField('oneof_nested_message'))
self.assertTrue(m.HasField('oneof_bytes'))
- def testOneofCompositeFieldReadAccess(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofCompositeFieldReadAccess(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
self.ensureNestedMessageExists(m, 'oneof_nested_message')
self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
self.assertEqual(11, m.oneof_uint32)
- def testOneofHasField(self):
- m = unittest_pb2.TestAllTypes()
- self.assertFalse(m.HasField('oneof_field'))
+ def testOneofWhichOneof(self, message_module):
+ m = message_module.TestAllTypes()
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertFalse(m.HasField('oneof_field'))
+
m.oneof_uint32 = 11
- self.assertTrue(m.HasField('oneof_field'))
+ self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertTrue(m.HasField('oneof_field'))
+
m.oneof_bytes = b'bb'
- self.assertTrue(m.HasField('oneof_field'))
+ self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
+
m.ClearField('oneof_bytes')
- self.assertFalse(m.HasField('oneof_field'))
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertFalse(m.HasField('oneof_field'))
- def testOneofClearField(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofClearField(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
m.ClearField('oneof_field')
- self.assertFalse(m.HasField('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertFalse(m.HasField('oneof_field'))
self.assertFalse(m.HasField('oneof_uint32'))
self.assertIs(None, m.WhichOneof('oneof_field'))
- def testOneofClearSetField(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofClearSetField(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
m.ClearField('oneof_uint32')
- self.assertFalse(m.HasField('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertFalse(m.HasField('oneof_field'))
self.assertFalse(m.HasField('oneof_uint32'))
self.assertIs(None, m.WhichOneof('oneof_field'))
- def testOneofClearUnsetField(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofClearUnsetField(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
self.ensureNestedMessageExists(m, 'oneof_nested_message')
m.ClearField('oneof_nested_message')
self.assertEqual(11, m.oneof_uint32)
- self.assertTrue(m.HasField('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertTrue(m.HasField('oneof_field'))
self.assertTrue(m.HasField('oneof_uint32'))
self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
- def testOneofDeserialize(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofDeserialize(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
- m2 = unittest_pb2.TestAllTypes()
+ m2 = message_module.TestAllTypes()
m2.ParseFromString(m.SerializeToString())
self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
- def testOneofCopyFrom(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofCopyFrom(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
- m2 = unittest_pb2.TestAllTypes()
+ m2 = message_module.TestAllTypes()
m2.CopyFrom(m)
self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
- def testOneofNestedMergeFrom(self):
- m = unittest_pb2.NestedTestAllTypes()
+ def testOneofNestedMergeFrom(self, message_module):
+ m = message_module.NestedTestAllTypes()
m.payload.oneof_uint32 = 11
- m2 = unittest_pb2.NestedTestAllTypes()
+ m2 = message_module.NestedTestAllTypes()
m2.payload.oneof_bytes = b'bb'
m2.child.payload.oneof_bytes = b'bb'
m2.MergeFrom(m)
self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field'))
self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field'))
- def testOneofClear(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofMessageMergeFrom(self, message_module):
+ m = message_module.NestedTestAllTypes()
+ m.payload.oneof_nested_message.bb = 11
+ m.child.payload.oneof_nested_message.bb = 12
+ m2 = message_module.NestedTestAllTypes()
+ m2.payload.oneof_uint32 = 13
+ m2.MergeFrom(m)
+ self.assertEqual('oneof_nested_message',
+ m2.payload.WhichOneof('oneof_field'))
+ self.assertEqual('oneof_nested_message',
+ m2.child.payload.WhichOneof('oneof_field'))
+
+ def testOneofNestedMessageInit(self, message_module):
+ m = message_module.TestAllTypes(
+ oneof_nested_message=message_module.TestAllTypes.NestedMessage())
+ self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
+
+ def testOneofClear(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
m.Clear()
self.assertIsNone(m.WhichOneof('oneof_field'))
m.oneof_bytes = b'bb'
- self.assertTrue(m.HasField('oneof_field'))
+ self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
+
+ def testAssignByteStringToUnicodeField(self, message_module):
+ """Assigning a byte string to a string field should result
+ in the value being converted to a Unicode string."""
+ m = message_module.TestAllTypes()
+ m.optional_string = str('')
+ self.assertTrue(isinstance(m.optional_string, unicode))
+# TODO(haberman): why are these tests Google-internal only?
- def testSortEmptyRepeatedCompositeContainer(self):
+ def testLongValuedSlice(self, message_module):
+ """It should be possible to use long-valued indicies in slices
+
+ This didn't used to work in the v2 C++ implementation.
+ """
+ m = message_module.TestAllTypes()
+
+ # Repeated scalar
+ m.repeated_int32.append(1)
+ sl = m.repeated_int32[long(0):long(len(m.repeated_int32))]
+ self.assertEqual(len(m.repeated_int32), len(sl))
+
+ # Repeated composite
+ m.repeated_nested_message.add().bb = 3
+ sl = m.repeated_nested_message[long(0):long(len(m.repeated_nested_message))]
+ self.assertEqual(len(m.repeated_nested_message), len(sl))
+
+ def testExtendShouldNotSwallowExceptions(self, message_module):
+ """This didn't use to work in the v2 C++ implementation."""
+ m = message_module.TestAllTypes()
+ with self.assertRaises(NameError) as _:
+ m.repeated_int32.extend(a for i in range(10)) # pylint: disable=undefined-variable
+ with self.assertRaises(NameError) as _:
+ m.repeated_nested_enum.extend(
+ a for i in range(10)) # pylint: disable=undefined-variable
+
+ FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()]
+
+ def testExtendInt32WithNothing(self, message_module):
+ """Test no-ops extending repeated int32 fields."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_int32)
+
+ # TODO(ptucker): Deprecate this behavior. b/18413862
+ for falsy_value in MessageTest.FALSY_VALUES:
+ m.repeated_int32.extend(falsy_value)
+ self.assertSequenceEqual([], m.repeated_int32)
+
+ m.repeated_int32.extend([])
+ self.assertSequenceEqual([], m.repeated_int32)
+
+ def testExtendFloatWithNothing(self, message_module):
+ """Test no-ops extending repeated float fields."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_float)
+
+ # TODO(ptucker): Deprecate this behavior. b/18413862
+ for falsy_value in MessageTest.FALSY_VALUES:
+ m.repeated_float.extend(falsy_value)
+ self.assertSequenceEqual([], m.repeated_float)
+
+ m.repeated_float.extend([])
+ self.assertSequenceEqual([], m.repeated_float)
+
+ def testExtendStringWithNothing(self, message_module):
+ """Test no-ops extending repeated string fields."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_string)
+
+ # TODO(ptucker): Deprecate this behavior. b/18413862
+ for falsy_value in MessageTest.FALSY_VALUES:
+ m.repeated_string.extend(falsy_value)
+ self.assertSequenceEqual([], m.repeated_string)
+
+ m.repeated_string.extend([])
+ self.assertSequenceEqual([], m.repeated_string)
+
+ def testExtendInt32WithPythonList(self, message_module):
+ """Test extending repeated int32 fields with python lists."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_int32)
+ m.repeated_int32.extend([0])
+ self.assertSequenceEqual([0], m.repeated_int32)
+ m.repeated_int32.extend([1, 2])
+ self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
+ m.repeated_int32.extend([3, 4])
+ self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
+
+ def testExtendFloatWithPythonList(self, message_module):
+ """Test extending repeated float fields with python lists."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_float)
+ m.repeated_float.extend([0.0])
+ self.assertSequenceEqual([0.0], m.repeated_float)
+ m.repeated_float.extend([1.0, 2.0])
+ self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
+ m.repeated_float.extend([3.0, 4.0])
+ self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
+
+ def testExtendStringWithPythonList(self, message_module):
+ """Test extending repeated string fields with python lists."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_string)
+ m.repeated_string.extend([''])
+ self.assertSequenceEqual([''], m.repeated_string)
+ m.repeated_string.extend(['11', '22'])
+ self.assertSequenceEqual(['', '11', '22'], m.repeated_string)
+ m.repeated_string.extend(['33', '44'])
+ self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string)
+
+ def testExtendStringWithString(self, message_module):
+ """Test extending repeated string fields with characters from a string."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_string)
+ m.repeated_string.extend('abc')
+ self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string)
+
+ class TestIterable(object):
+ """This iterable object mimics the behavior of numpy.array.
+
+ __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1.
+
+ """
+
+ def __init__(self, values=None):
+ self._list = values or []
+
+ def __nonzero__(self):
+ size = len(self._list)
+ if size == 0:
+ return False
+ if size == 1:
+ return bool(self._list[0])
+ raise ValueError('Truth value is ambiguous.')
+
+ def __len__(self):
+ return len(self._list)
+
+ def __iter__(self):
+ return self._list.__iter__()
+
+ def testExtendInt32WithIterable(self, message_module):
+ """Test extending repeated int32 fields with iterable."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_int32)
+ m.repeated_int32.extend(MessageTest.TestIterable([]))
+ self.assertSequenceEqual([], m.repeated_int32)
+ m.repeated_int32.extend(MessageTest.TestIterable([0]))
+ self.assertSequenceEqual([0], m.repeated_int32)
+ m.repeated_int32.extend(MessageTest.TestIterable([1, 2]))
+ self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
+ m.repeated_int32.extend(MessageTest.TestIterable([3, 4]))
+ self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
+
+ def testExtendFloatWithIterable(self, message_module):
+ """Test extending repeated float fields with iterable."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_float)
+ m.repeated_float.extend(MessageTest.TestIterable([]))
+ self.assertSequenceEqual([], m.repeated_float)
+ m.repeated_float.extend(MessageTest.TestIterable([0.0]))
+ self.assertSequenceEqual([0.0], m.repeated_float)
+ m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0]))
+ self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
+ m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0]))
+ self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
+
+ def testExtendStringWithIterable(self, message_module):
+ """Test extending repeated string fields with iterable."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_string)
+ m.repeated_string.extend(MessageTest.TestIterable([]))
+ self.assertSequenceEqual([], m.repeated_string)
+ m.repeated_string.extend(MessageTest.TestIterable(['']))
+ self.assertSequenceEqual([''], m.repeated_string)
+ m.repeated_string.extend(MessageTest.TestIterable(['1', '2']))
+ self.assertSequenceEqual(['', '1', '2'], m.repeated_string)
+ m.repeated_string.extend(MessageTest.TestIterable(['3', '4']))
+ self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string)
+
+ def testPickleRepeatedScalarContainer(self, message_module):
+ # TODO(tibell): The pure-Python implementation support pickling of
+ # scalar containers in *some* cases. For now the cpp2 version
+ # throws an exception to avoid a segfault. Investigate if we
+ # want to support pickling of these fields.
+ #
+ # For more information see: https://b2.corp.google.com/u/0/issues/18677897
+ if (api_implementation.Type() != 'cpp' or
+ api_implementation.Version() == 2):
+ return
+ m = message_module.TestAllTypes()
+ with self.assertRaises(pickle.PickleError) as _:
+ pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL)
+
+
+ def testSortEmptyRepeatedCompositeContainer(self, message_module):
"""Exercise a scenario that has led to segfaults in the past.
"""
- m = unittest_pb2.TestAllTypes()
+ m = message_module.TestAllTypes()
m.repeated_nested_message.sort()
- def testHasFieldOnRepeatedField(self):
+ def testHasFieldOnRepeatedField(self, message_module):
"""Using HasField on a repeated field should raise an exception.
"""
- m = unittest_pb2.TestAllTypes()
+ m = message_module.TestAllTypes()
with self.assertRaises(ValueError) as _:
m.HasField('repeated_int32')
+ def testRepeatedScalarFieldPop(self, message_module):
+ m = message_module.TestAllTypes()
+ with self.assertRaises(IndexError) as _:
+ m.repeated_int32.pop()
+ m.repeated_int32.extend(range(5))
+ self.assertEqual(4, m.repeated_int32.pop())
+ self.assertEqual(0, m.repeated_int32.pop(0))
+ self.assertEqual(2, m.repeated_int32.pop(1))
+ self.assertEqual([1, 3], m.repeated_int32)
+
+ def testRepeatedCompositeFieldPop(self, message_module):
+ m = message_module.TestAllTypes()
+ with self.assertRaises(IndexError) as _:
+ m.repeated_nested_message.pop()
+ for i in range(5):
+ n = m.repeated_nested_message.add()
+ n.bb = i
+ self.assertEqual(4, m.repeated_nested_message.pop().bb)
+ self.assertEqual(0, m.repeated_nested_message.pop(0).bb)
+ self.assertEqual(2, m.repeated_nested_message.pop(1).bb)
+ self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message])
+
+
+# Class to test proto2-only features (required, extensions, etc.)
+class Proto2Test(basetest.TestCase):
+
+ def testFieldPresence(self):
+ message = unittest_pb2.TestAllTypes()
+
+ self.assertFalse(message.HasField("optional_int32"))
+ self.assertFalse(message.HasField("optional_bool"))
+ self.assertFalse(message.HasField("optional_nested_message"))
+
+ with self.assertRaises(ValueError):
+ message.HasField("field_doesnt_exist")
+
+ with self.assertRaises(ValueError):
+ message.HasField("repeated_int32")
+ with self.assertRaises(ValueError):
+ message.HasField("repeated_nested_message")
+
+ self.assertEqual(0, message.optional_int32)
+ self.assertEqual(False, message.optional_bool)
+ self.assertEqual(0, message.optional_nested_message.bb)
+
+ # Fields are set even when setting the values to default values.
+ message.optional_int32 = 0
+ message.optional_bool = False
+ message.optional_nested_message.bb = 0
+ self.assertTrue(message.HasField("optional_int32"))
+ self.assertTrue(message.HasField("optional_bool"))
+ self.assertTrue(message.HasField("optional_nested_message"))
+
+ # Set the fields to non-default values.
+ message.optional_int32 = 5
+ message.optional_bool = True
+ message.optional_nested_message.bb = 15
+
+ self.assertTrue(message.HasField("optional_int32"))
+ self.assertTrue(message.HasField("optional_bool"))
+ self.assertTrue(message.HasField("optional_nested_message"))
+
+ # Clearing the fields unsets them and resets their value to default.
+ message.ClearField("optional_int32")
+ message.ClearField("optional_bool")
+ message.ClearField("optional_nested_message")
+
+ self.assertFalse(message.HasField("optional_int32"))
+ self.assertFalse(message.HasField("optional_bool"))
+ self.assertFalse(message.HasField("optional_nested_message"))
+ self.assertEqual(0, message.optional_int32)
+ self.assertEqual(False, message.optional_bool)
+ self.assertEqual(0, message.optional_nested_message.bb)
+
+ # TODO(tibell): The C++ implementations actually allows assignment
+ # of unknown enum values to *scalar* fields (but not repeated
+ # fields). Once checked enum fields becomes the default in the
+ # Python implementation, the C++ implementation should follow suit.
+ def testAssignInvalidEnum(self):
+ """It should not be possible to assign an invalid enum number to an
+ enum field."""
+ m = unittest_pb2.TestAllTypes()
+
+ with self.assertRaises(ValueError) as _:
+ m.optional_nested_enum = 1234567
+ self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567)
+
+ def testGoldenExtensions(self):
+ golden_data = test_util.GoldenFileData('golden_message')
+ golden_message = unittest_pb2.TestAllExtensions()
+ golden_message.ParseFromString(golden_data)
+ all_set = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(all_set)
+ self.assertEqual(all_set, golden_message)
+ self.assertEqual(golden_data, golden_message.SerializeToString())
+ golden_copy = copy.deepcopy(golden_message)
+ self.assertEqual(golden_data, golden_copy.SerializeToString())
+
+ def testGoldenPackedExtensions(self):
+ golden_data = test_util.GoldenFileData('golden_packed_fields_message')
+ golden_message = unittest_pb2.TestPackedExtensions()
+ golden_message.ParseFromString(golden_data)
+ all_set = unittest_pb2.TestPackedExtensions()
+ test_util.SetAllPackedExtensions(all_set)
+ self.assertEqual(all_set, golden_message)
+ self.assertEqual(golden_data, all_set.SerializeToString())
+ golden_copy = copy.deepcopy(golden_message)
+ self.assertEqual(golden_data, golden_copy.SerializeToString())
+
+ def testPickleIncompleteProto(self):
+ golden_message = unittest_pb2.TestRequired(a=1)
+ pickled_message = pickle.dumps(golden_message)
+
+ unpickled_message = pickle.loads(pickled_message)
+ self.assertEqual(unpickled_message, golden_message)
+ self.assertEqual(unpickled_message.a, 1)
+ # This is still an incomplete proto - so serializing should fail
+ self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
+
+
+ # TODO(haberman): this isn't really a proto2-specific test except that this
+ # message has a required field in it. Should probably be factored out so
+ # that we can test the other parts with proto3.
+ def testParsingMerge(self):
+ """Check the merge behavior when a required or optional field appears
+ multiple times in the input."""
+ messages = [
+ unittest_pb2.TestAllTypes(),
+ unittest_pb2.TestAllTypes(),
+ unittest_pb2.TestAllTypes() ]
+ messages[0].optional_int32 = 1
+ messages[1].optional_int64 = 2
+ messages[2].optional_int32 = 3
+ messages[2].optional_string = 'hello'
+
+ merged_message = unittest_pb2.TestAllTypes()
+ merged_message.optional_int32 = 3
+ merged_message.optional_int64 = 2
+ merged_message.optional_string = 'hello'
+
+ generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
+ generator.field1.extend(messages)
+ generator.field2.extend(messages)
+ generator.field3.extend(messages)
+ generator.ext1.extend(messages)
+ generator.ext2.extend(messages)
+ generator.group1.add().field1.MergeFrom(messages[0])
+ generator.group1.add().field1.MergeFrom(messages[1])
+ generator.group1.add().field1.MergeFrom(messages[2])
+ generator.group2.add().field1.MergeFrom(messages[0])
+ generator.group2.add().field1.MergeFrom(messages[1])
+ generator.group2.add().field1.MergeFrom(messages[2])
+
+ data = generator.SerializeToString()
+ parsing_merge = unittest_pb2.TestParsingMerge()
+ parsing_merge.ParseFromString(data)
+
+ # Required and optional fields should be merged.
+ self.assertEqual(parsing_merge.required_all_types, merged_message)
+ self.assertEqual(parsing_merge.optional_all_types, merged_message)
+ self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
+ merged_message)
+ self.assertEqual(parsing_merge.Extensions[
+ unittest_pb2.TestParsingMerge.optional_ext],
+ merged_message)
+
+ # Repeated fields should not be merged.
+ self.assertEqual(len(parsing_merge.repeated_all_types), 3)
+ self.assertEqual(len(parsing_merge.repeatedgroup), 3)
+ self.assertEqual(len(parsing_merge.Extensions[
+ unittest_pb2.TestParsingMerge.repeated_ext]), 3)
+
+
+# Class to test proto3-only features/behavior (updated field presence & enums)
+class Proto3Test(basetest.TestCase):
+
+ def testFieldPresence(self):
+ message = unittest_proto3_arena_pb2.TestAllTypes()
+
+ # We can't test presence of non-repeated, non-submessage fields.
+ with self.assertRaises(ValueError):
+ message.HasField("optional_int32")
+ with self.assertRaises(ValueError):
+ message.HasField("optional_float")
+ with self.assertRaises(ValueError):
+ message.HasField("optional_string")
+ with self.assertRaises(ValueError):
+ message.HasField("optional_bool")
+
+ # But we can still test presence of submessage fields.
+ self.assertFalse(message.HasField("optional_nested_message"))
+
+ # As with proto2, we can't test presence of fields that don't exist, or
+ # repeated fields.
+ with self.assertRaises(ValueError):
+ message.HasField("field_doesnt_exist")
+
+ with self.assertRaises(ValueError):
+ message.HasField("repeated_int32")
+ with self.assertRaises(ValueError):
+ message.HasField("repeated_nested_message")
+
+ # Fields should default to their type-specific default.
+ self.assertEqual(0, message.optional_int32)
+ self.assertEqual(0, message.optional_float)
+ self.assertEqual("", message.optional_string)
+ self.assertEqual(False, message.optional_bool)
+ self.assertEqual(0, message.optional_nested_message.bb)
+
+ # Setting a submessage should still return proper presence information.
+ message.optional_nested_message.bb = 0
+ self.assertTrue(message.HasField("optional_nested_message"))
+
+ # Set the fields to non-default values.
+ message.optional_int32 = 5
+ message.optional_float = 1.1
+ message.optional_string = "abc"
+ message.optional_bool = True
+ message.optional_nested_message.bb = 15
+
+ # Clearing the fields unsets them and resets their value to default.
+ message.ClearField("optional_int32")
+ message.ClearField("optional_float")
+ message.ClearField("optional_string")
+ message.ClearField("optional_bool")
+ message.ClearField("optional_nested_message")
+
+ self.assertEqual(0, message.optional_int32)
+ self.assertEqual(0, message.optional_float)
+ self.assertEqual("", message.optional_string)
+ self.assertEqual(False, message.optional_bool)
+ self.assertEqual(0, message.optional_nested_message.bb)
+
+ def testAssignUnknownEnum(self):
+ """Assigning an unknown enum value is allowed and preserves the value."""
+ m = unittest_proto3_arena_pb2.TestAllTypes()
+
+ m.optional_nested_enum = 1234567
+ self.assertEqual(1234567, m.optional_nested_enum)
+ m.repeated_nested_enum.append(22334455)
+ self.assertEqual(22334455, m.repeated_nested_enum[0])
+ # Assignment is a different code path than append for the C++ impl.
+ m.repeated_nested_enum[0] = 7654321
+ self.assertEqual(7654321, m.repeated_nested_enum[0])
+ serialized = m.SerializeToString()
+
+ m2 = unittest_proto3_arena_pb2.TestAllTypes()
+ m2.ParseFromString(serialized)
+ self.assertEqual(1234567, m2.optional_nested_enum)
+ self.assertEqual(7654321, m2.repeated_nested_enum[0])
+
class ValidTypeNamesTest(basetest.TestCase):
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 6fda6ae0..6ad0f90d 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -219,12 +219,20 @@ def _AttachFieldHelpers(cls, field_descriptor):
def AddDecoder(wiretype, is_packed):
tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
- cls._decoders_by_tag[tag_bytes] = (
- type_checkers.TYPE_TO_DECODER[field_descriptor.type](
- field_descriptor.number, is_repeated, is_packed,
- field_descriptor, field_descriptor._default_constructor),
- field_descriptor if field_descriptor.containing_oneof is not None
- else None)
+ decode_type = field_descriptor.type
+ if (decode_type == _FieldDescriptor.TYPE_ENUM and
+ type_checkers.SupportsOpenEnums(field_descriptor)):
+ decode_type = _FieldDescriptor.TYPE_INT32
+
+ oneof_descriptor = None
+ if field_descriptor.containing_oneof is not None:
+ oneof_descriptor = field_descriptor
+
+ field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor)
+
+ cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
False)
@@ -296,6 +304,8 @@ def _DefaultValueConstructorForField(field):
def MakeSubMessageDefault(message):
result = message_type._concrete_class()
result._SetListener(message._listener_for_children)
+ if field.containing_oneof:
+ message._UpdateOneofState(field)
return result
return MakeSubMessageDefault
@@ -476,6 +486,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
type_checker = type_checkers.GetTypeChecker(field)
default_value = field.default_value
valid_values = set()
+ is_proto3 = field.containing_type.syntax == "proto3"
def getter(self):
# TODO(protobuf-team): This may be broken since there may not be
@@ -483,15 +494,24 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
return self._fields.get(field, default_value)
getter.__module__ = None
getter.__doc__ = 'Getter for %s.' % proto_field_name
+
+ clear_when_set_to_default = is_proto3 and not field.containing_oneof
+
def field_setter(self, new_value):
# pylint: disable=protected-access
- self._fields[field] = type_checker.CheckValue(new_value)
+ # Testing the value for truthiness captures all of the proto3 defaults
+ # (0, 0.0, enum 0, and False).
+ new_value = type_checker.CheckValue(new_value)
+ if clear_when_set_to_default and not new_value:
+ self._fields.pop(field, None)
+ else:
+ self._fields[field] = new_value
# Check _cached_byte_size_dirty inline to improve performance, since scalar
# setters are called frequently.
if not self._cached_byte_size_dirty:
self._Modified()
- if field.containing_oneof is not None:
+ if field.containing_oneof:
def setter(self, new_value):
field_setter(self, new_value)
self._UpdateOneofState(field)
@@ -624,24 +644,35 @@ def _AddListFieldsMethod(message_descriptor, cls):
cls.ListFields = ListFields
+_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
+_Proto2HasError = 'Protocol message has no non-repeated field "%s"'
def _AddHasFieldMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- singular_fields = {}
+ is_proto3 = (message_descriptor.syntax == "proto3")
+ error_msg = _Proto3HasError if is_proto3 else _Proto2HasError
+
+ hassable_fields = {}
for field in message_descriptor.fields:
- if field.label != _FieldDescriptor.LABEL_REPEATED:
- singular_fields[field.name] = field
- # Fields inside oneofs are never repeated (enforced by the compiler).
- for field in message_descriptor.oneofs:
- singular_fields[field.name] = field
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ continue
+ # For proto3, only submessages and fields inside a oneof have presence.
+ if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
+ not field.containing_oneof):
+ continue
+ hassable_fields[field.name] = field
+
+ if not is_proto3:
+ # Fields inside oneofs are never repeated (enforced by the compiler).
+ for oneof in message_descriptor.oneofs:
+ hassable_fields[oneof.name] = oneof
def HasField(self, field_name):
try:
- field = singular_fields[field_name]
+ field = hassable_fields[field_name]
except KeyError:
- raise ValueError(
- 'Protocol message has no singular "%s" field.' % field_name)
+ raise ValueError(error_msg % field_name)
if isinstance(field, descriptor_mod.OneofDescriptor):
try:
@@ -871,6 +902,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField
decoders_by_tag = cls._decoders_by_tag
+ is_proto3 = message_descriptor.syntax == "proto3"
def InternalParse(self, buffer, pos, end):
self._Modified()
@@ -884,9 +916,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
if new_pos == -1:
return pos
- if not unknown_field_list:
- unknown_field_list = self._unknown_fields = []
- unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
+ if not is_proto3:
+ if not unknown_field_list:
+ unknown_field_list = self._unknown_fields = []
+ unknown_field_list.append(
+ (tag_bytes, buffer[value_start_pos:new_pos]))
pos = new_pos
else:
pos = field_decoder(buffer, new_pos, end, self, field_dict)
@@ -1008,6 +1042,8 @@ def _AddMergeFromMethod(cls):
# Construct a new object to represent this field.
field_value = field._default_constructor(self)
fields[field] = field_value
+ if field.containing_oneof:
+ self._UpdateOneofState(field)
field_value.MergeFrom(value)
else:
self._fields[field] = value
@@ -1252,11 +1288,10 @@ class _ExtensionDict(object):
# It's slightly wasteful to lookup the type checker each time,
# but we expect this to be a vanishingly uncommon case anyway.
- type_checker = type_checkers.GetTypeChecker(
- extension_handle)
+ type_checker = type_checkers.GetTypeChecker(extension_handle)
# pylint: disable=protected-access
self._extended_message._fields[extension_handle] = (
- type_checker.CheckValue(value))
+ type_checker.CheckValue(value))
self._extended_message._Modified()
def _FindExtensionByName(self, name):
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 6b24b092..8ac76c63 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -1792,6 +1792,27 @@ class ReflectionTest(basetest.TestCase):
# Just check the default value.
self.assertEqual(57, msg.inner.value)
+ @basetest.unittest.skipIf(
+ api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
+ 'CPPv2-specific test')
+ def testBadArguments(self):
+ # Some of these assertions used to segfault.
+ from google.protobuf.pyext import _message
+ self.assertRaises(TypeError, _message.Message._GetFieldDescriptor, 3)
+ self.assertRaises(TypeError, _message.Message._GetExtensionDescriptor, 42)
+ self.assertRaises(TypeError,
+ unittest_pb2.TestAllTypes().__getattribute__, 42)
+
+ @basetest.unittest.skipIf(
+ api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
+ 'CPPv2-specific test')
+ def testRosyHack(self):
+ from google.protobuf.pyext import _message
+ from google3.gdata.rosy.proto import core_api2_pb2
+ from google3.gdata.rosy.proto import core_pb2
+ self.assertEqual(_message.Message, core_pb2.PageSelection.__base__)
+ self.assertEqual(_message.Message, core_api2_pb2.PageSelection.__base__)
+
# Since we had so many tests for protocol buffer equality, we broke these out
# into separate TestCase classes.
diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py
index 787f4650..d84e3836 100755
--- a/python/google/protobuf/internal/test_util.py
+++ b/python/google/protobuf/internal/test_util.py
@@ -40,13 +40,19 @@ import os.path
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
+from google.protobuf import descriptor_pb2
+# Tests whether the given TestAllTypes message is proto2 or not.
+# This is used to gate several fields/features that only exist
+# for the proto2 version of the message.
+def IsProto2(message):
+ return message.DESCRIPTOR.syntax == "proto2"
def SetAllNonLazyFields(message):
"""Sets every non-lazy field in the message to a unique value.
Args:
- message: A unittest_pb2.TestAllTypes instance.
+ message: A TestAllTypes instance.
"""
#
@@ -77,7 +83,8 @@ def SetAllNonLazyFields(message):
message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ
message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ
- message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ
+ if IsProto2(message):
+ message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ
message.optional_string_piece = u'124'
message.optional_cord = u'125'
@@ -110,7 +117,8 @@ def SetAllNonLazyFields(message):
message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR)
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR)
- message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR)
+ if IsProto2(message):
+ message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR)
message.repeated_string_piece.append(u'224')
message.repeated_cord.append(u'225')
@@ -140,7 +148,8 @@ def SetAllNonLazyFields(message):
message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ)
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ)
- message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ)
+ if IsProto2(message):
+ message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ)
message.repeated_string_piece.append(u'324')
message.repeated_cord.append(u'325')
@@ -149,28 +158,29 @@ def SetAllNonLazyFields(message):
# Fields that have defaults.
#
- message.default_int32 = 401
- message.default_int64 = 402
- message.default_uint32 = 403
- message.default_uint64 = 404
- message.default_sint32 = 405
- message.default_sint64 = 406
- message.default_fixed32 = 407
- message.default_fixed64 = 408
- message.default_sfixed32 = 409
- message.default_sfixed64 = 410
- message.default_float = 411
- message.default_double = 412
- message.default_bool = False
- message.default_string = '415'
- message.default_bytes = b'416'
-
- message.default_nested_enum = unittest_pb2.TestAllTypes.FOO
- message.default_foreign_enum = unittest_pb2.FOREIGN_FOO
- message.default_import_enum = unittest_import_pb2.IMPORT_FOO
-
- message.default_string_piece = '424'
- message.default_cord = '425'
+ if IsProto2(message):
+ message.default_int32 = 401
+ message.default_int64 = 402
+ message.default_uint32 = 403
+ message.default_uint64 = 404
+ message.default_sint32 = 405
+ message.default_sint64 = 406
+ message.default_fixed32 = 407
+ message.default_fixed64 = 408
+ message.default_sfixed32 = 409
+ message.default_sfixed64 = 410
+ message.default_float = 411
+ message.default_double = 412
+ message.default_bool = False
+ message.default_string = '415'
+ message.default_bytes = b'416'
+
+ message.default_nested_enum = unittest_pb2.TestAllTypes.FOO
+ message.default_foreign_enum = unittest_pb2.FOREIGN_FOO
+ message.default_import_enum = unittest_import_pb2.IMPORT_FOO
+
+ message.default_string_piece = '424'
+ message.default_cord = '425'
message.oneof_uint32 = 601
message.oneof_nested_message.bb = 602
@@ -398,7 +408,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertTrue(message.HasField('optional_nested_enum'))
test_case.assertTrue(message.HasField('optional_foreign_enum'))
- test_case.assertTrue(message.HasField('optional_import_enum'))
+ if IsProto2(message):
+ test_case.assertTrue(message.HasField('optional_import_enum'))
test_case.assertTrue(message.HasField('optional_string_piece'))
test_case.assertTrue(message.HasField('optional_cord'))
@@ -430,8 +441,9 @@ def ExpectAllFieldsSet(test_case, message):
message.optional_nested_enum)
test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
message.optional_foreign_enum)
- test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
- message.optional_import_enum)
+ if IsProto2(message):
+ test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
+ message.optional_import_enum)
# -----------------------------------------------------------------
@@ -457,7 +469,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual(2, len(message.repeated_import_message))
test_case.assertEqual(2, len(message.repeated_nested_enum))
test_case.assertEqual(2, len(message.repeated_foreign_enum))
- test_case.assertEqual(2, len(message.repeated_import_enum))
+ if IsProto2(message):
+ test_case.assertEqual(2, len(message.repeated_import_enum))
test_case.assertEqual(2, len(message.repeated_string_piece))
test_case.assertEqual(2, len(message.repeated_cord))
@@ -488,8 +501,9 @@ def ExpectAllFieldsSet(test_case, message):
message.repeated_nested_enum[0])
test_case.assertEqual(unittest_pb2.FOREIGN_BAR,
message.repeated_foreign_enum[0])
- test_case.assertEqual(unittest_import_pb2.IMPORT_BAR,
- message.repeated_import_enum[0])
+ if IsProto2(message):
+ test_case.assertEqual(unittest_import_pb2.IMPORT_BAR,
+ message.repeated_import_enum[0])
test_case.assertEqual(301, message.repeated_int32[1])
test_case.assertEqual(302, message.repeated_int64[1])
@@ -517,53 +531,55 @@ def ExpectAllFieldsSet(test_case, message):
message.repeated_nested_enum[1])
test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
message.repeated_foreign_enum[1])
- test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
- message.repeated_import_enum[1])
+ if IsProto2(message):
+ test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
+ message.repeated_import_enum[1])
# -----------------------------------------------------------------
- test_case.assertTrue(message.HasField('default_int32'))
- test_case.assertTrue(message.HasField('default_int64'))
- test_case.assertTrue(message.HasField('default_uint32'))
- test_case.assertTrue(message.HasField('default_uint64'))
- test_case.assertTrue(message.HasField('default_sint32'))
- test_case.assertTrue(message.HasField('default_sint64'))
- test_case.assertTrue(message.HasField('default_fixed32'))
- test_case.assertTrue(message.HasField('default_fixed64'))
- test_case.assertTrue(message.HasField('default_sfixed32'))
- test_case.assertTrue(message.HasField('default_sfixed64'))
- test_case.assertTrue(message.HasField('default_float'))
- test_case.assertTrue(message.HasField('default_double'))
- test_case.assertTrue(message.HasField('default_bool'))
- test_case.assertTrue(message.HasField('default_string'))
- test_case.assertTrue(message.HasField('default_bytes'))
-
- test_case.assertTrue(message.HasField('default_nested_enum'))
- test_case.assertTrue(message.HasField('default_foreign_enum'))
- test_case.assertTrue(message.HasField('default_import_enum'))
-
- test_case.assertEqual(401, message.default_int32)
- test_case.assertEqual(402, message.default_int64)
- test_case.assertEqual(403, message.default_uint32)
- test_case.assertEqual(404, message.default_uint64)
- test_case.assertEqual(405, message.default_sint32)
- test_case.assertEqual(406, message.default_sint64)
- test_case.assertEqual(407, message.default_fixed32)
- test_case.assertEqual(408, message.default_fixed64)
- test_case.assertEqual(409, message.default_sfixed32)
- test_case.assertEqual(410, message.default_sfixed64)
- test_case.assertEqual(411, message.default_float)
- test_case.assertEqual(412, message.default_double)
- test_case.assertEqual(False, message.default_bool)
- test_case.assertEqual('415', message.default_string)
- test_case.assertEqual(b'416', message.default_bytes)
-
- test_case.assertEqual(unittest_pb2.TestAllTypes.FOO,
- message.default_nested_enum)
- test_case.assertEqual(unittest_pb2.FOREIGN_FOO,
- message.default_foreign_enum)
- test_case.assertEqual(unittest_import_pb2.IMPORT_FOO,
- message.default_import_enum)
+ if IsProto2(message):
+ test_case.assertTrue(message.HasField('default_int32'))
+ test_case.assertTrue(message.HasField('default_int64'))
+ test_case.assertTrue(message.HasField('default_uint32'))
+ test_case.assertTrue(message.HasField('default_uint64'))
+ test_case.assertTrue(message.HasField('default_sint32'))
+ test_case.assertTrue(message.HasField('default_sint64'))
+ test_case.assertTrue(message.HasField('default_fixed32'))
+ test_case.assertTrue(message.HasField('default_fixed64'))
+ test_case.assertTrue(message.HasField('default_sfixed32'))
+ test_case.assertTrue(message.HasField('default_sfixed64'))
+ test_case.assertTrue(message.HasField('default_float'))
+ test_case.assertTrue(message.HasField('default_double'))
+ test_case.assertTrue(message.HasField('default_bool'))
+ test_case.assertTrue(message.HasField('default_string'))
+ test_case.assertTrue(message.HasField('default_bytes'))
+
+ test_case.assertTrue(message.HasField('default_nested_enum'))
+ test_case.assertTrue(message.HasField('default_foreign_enum'))
+ test_case.assertTrue(message.HasField('default_import_enum'))
+
+ test_case.assertEqual(401, message.default_int32)
+ test_case.assertEqual(402, message.default_int64)
+ test_case.assertEqual(403, message.default_uint32)
+ test_case.assertEqual(404, message.default_uint64)
+ test_case.assertEqual(405, message.default_sint32)
+ test_case.assertEqual(406, message.default_sint64)
+ test_case.assertEqual(407, message.default_fixed32)
+ test_case.assertEqual(408, message.default_fixed64)
+ test_case.assertEqual(409, message.default_sfixed32)
+ test_case.assertEqual(410, message.default_sfixed64)
+ test_case.assertEqual(411, message.default_float)
+ test_case.assertEqual(412, message.default_double)
+ test_case.assertEqual(False, message.default_bool)
+ test_case.assertEqual('415', message.default_string)
+ test_case.assertEqual(b'416', message.default_bytes)
+
+ test_case.assertEqual(unittest_pb2.TestAllTypes.FOO,
+ message.default_nested_enum)
+ test_case.assertEqual(unittest_pb2.FOREIGN_FOO,
+ message.default_foreign_enum)
+ test_case.assertEqual(unittest_import_pb2.IMPORT_FOO,
+ message.default_import_enum)
def GoldenFile(filename):
@@ -594,7 +610,7 @@ def SetAllPackedFields(message):
"""Sets every field in the message to a unique value.
Args:
- message: A unittest_pb2.TestPackedTypes instance.
+ message: A TestPackedTypes instance.
"""
message.packed_int32.extend([601, 701])
message.packed_int64.extend([602, 702])
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index 55e3c2c8..68ab9659 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -37,13 +37,17 @@ __author__ = 'kenton@google.com (Kenton Varda)'
import re
from google.apputils import basetest
-from google.protobuf import text_format
+from google.apputils.pybase import parameterized
+
+from google.protobuf import unittest_mset_pb2
+from google.protobuf import unittest_pb2
+from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import test_util
-from google.protobuf import unittest_pb2
-from google.protobuf import unittest_mset_pb2
+from google.protobuf import text_format
-class TextFormatTest(basetest.TestCase):
+# Base class with some common functionality.
+class TextFormatBase(basetest.TestCase):
def ReadGolden(self, golden_filename):
with test_util.GoldenFile(golden_filename) as f:
@@ -57,73 +61,24 @@ class TextFormatTest(basetest.TestCase):
def CompareToGoldenText(self, text, golden_text):
self.assertMultiLineEqual(text, golden_text)
- def testPrintAllFields(self):
- message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(text_format.MessageToString(message)),
- 'text_format_unittest_data_oneof_implemented.txt')
-
- def testPrintInIndexOrder(self):
- message = unittest_pb2.TestFieldOrderings()
- message.my_string = '115'
- message.my_int = 101
- message.my_float = 111
- message.optional_nested_message.oo = 0
- message.optional_nested_message.bb = 1
- self.CompareToGoldenText(
- self.RemoveRedundantZeros(text_format.MessageToString(
- message, use_index_order=True)),
- 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n'
- 'optional_nested_message {\n oo: 0\n bb: 1\n}\n')
- self.CompareToGoldenText(
- self.RemoveRedundantZeros(text_format.MessageToString(
- message)),
- 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n'
- 'optional_nested_message {\n bb: 1\n oo: 0\n}\n')
-
- def testPrintAllExtensions(self):
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(text_format.MessageToString(message)),
- 'text_format_unittest_extensions_data.txt')
-
- def testPrintAllFieldsPointy(self):
- message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(
- text_format.MessageToString(message, pointy_brackets=True)),
- 'text_format_unittest_data_pointy_oneof.txt')
+ def RemoveRedundantZeros(self, text):
+ # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove
+ # these zeros in order to match the golden file.
+ text = text.replace('e+0','e+').replace('e+0','e+') \
+ .replace('e-0','e-').replace('e-0','e-')
+ # Floating point fields are printed with .0 suffix even if they are
+ # actualy integer numbers.
+ text = re.compile('\.0$', re.MULTILINE).sub('', text)
+ return text
- def testPrintAllExtensionsPointy(self):
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(text_format.MessageToString(
- message, pointy_brackets=True)),
- 'text_format_unittest_extensions_data_pointy.txt')
- def testPrintMessageSet(self):
- message = unittest_mset_pb2.TestMessageSetContainer()
- ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
- ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
- message.message_set.Extensions[ext1].i = 23
- message.message_set.Extensions[ext2].str = 'foo'
- self.CompareToGoldenText(
- text_format.MessageToString(message),
- 'message_set {\n'
- ' [protobuf_unittest.TestMessageSetExtension1] {\n'
- ' i: 23\n'
- ' }\n'
- ' [protobuf_unittest.TestMessageSetExtension2] {\n'
- ' str: \"foo\"\n'
- ' }\n'
- '}\n')
+@parameterized.Parameters(
+ (unittest_pb2),
+ (unittest_proto3_arena_pb2))
+class TextFormatTest(TextFormatBase):
- def testPrintExotic(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintExotic(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_int64.append(-9223372036854775808)
message.repeated_uint64.append(18446744073709551615)
message.repeated_double.append(123.456)
@@ -142,61 +97,44 @@ class TextFormatTest(basetest.TestCase):
' "\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n'
'repeated_string: "\\303\\274\\352\\234\\237"\n')
- def testPrintExoticUnicodeSubclass(self):
+ def testPrintExoticUnicodeSubclass(self, message_module):
class UnicodeSub(unicode):
pass
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f'))
self.CompareToGoldenText(
text_format.MessageToString(message),
'repeated_string: "\\303\\274\\352\\234\\237"\n')
- def testPrintNestedMessageAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintNestedMessageAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
msg = message.repeated_nested_message.add()
msg.bb = 42
self.CompareToGoldenText(
text_format.MessageToString(message, as_one_line=True),
'repeated_nested_message { bb: 42 }')
- def testPrintRepeatedFieldsAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintRepeatedFieldsAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_int32.append(1)
message.repeated_int32.append(1)
message.repeated_int32.append(3)
- message.repeated_string.append("Google")
- message.repeated_string.append("Zurich")
+ message.repeated_string.append('Google')
+ message.repeated_string.append('Zurich')
self.CompareToGoldenText(
text_format.MessageToString(message, as_one_line=True),
'repeated_int32: 1 repeated_int32: 1 repeated_int32: 3 '
'repeated_string: "Google" repeated_string: "Zurich"')
- def testPrintNestedNewLineInStringAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
- message.optional_string = "a\nnew\nline"
+ def testPrintNestedNewLineInStringAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
+ message.optional_string = 'a\nnew\nline'
self.CompareToGoldenText(
text_format.MessageToString(message, as_one_line=True),
'optional_string: "a\\nnew\\nline"')
- def testPrintMessageSetAsOneLine(self):
- message = unittest_mset_pb2.TestMessageSetContainer()
- ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
- ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
- message.message_set.Extensions[ext1].i = 23
- message.message_set.Extensions[ext2].str = 'foo'
- self.CompareToGoldenText(
- text_format.MessageToString(message, as_one_line=True),
- 'message_set {'
- ' [protobuf_unittest.TestMessageSetExtension1] {'
- ' i: 23'
- ' }'
- ' [protobuf_unittest.TestMessageSetExtension2] {'
- ' str: \"foo\"'
- ' }'
- ' }')
-
- def testPrintExoticAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintExoticAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_int64.append(-9223372036854775808)
message.repeated_uint64.append(18446744073709551615)
message.repeated_double.append(123.456)
@@ -216,8 +154,8 @@ class TextFormatTest(basetest.TestCase):
'"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""'
' repeated_string: "\\303\\274\\352\\234\\237"')
- def testRoundTripExoticAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
+ def testRoundTripExoticAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_int64.append(-9223372036854775808)
message.repeated_uint64.append(18446744073709551615)
message.repeated_double.append(123.456)
@@ -229,7 +167,7 @@ class TextFormatTest(basetest.TestCase):
# Test as_utf8 = False.
wire_text = text_format.MessageToString(
message, as_one_line=True, as_utf8=False)
- parsed_message = unittest_pb2.TestAllTypes()
+ parsed_message = message_module.TestAllTypes()
r = text_format.Parse(wire_text, parsed_message)
self.assertIs(r, parsed_message)
self.assertEquals(message, parsed_message)
@@ -237,25 +175,25 @@ class TextFormatTest(basetest.TestCase):
# Test as_utf8 = True.
wire_text = text_format.MessageToString(
message, as_one_line=True, as_utf8=True)
- parsed_message = unittest_pb2.TestAllTypes()
+ parsed_message = message_module.TestAllTypes()
r = text_format.Parse(wire_text, parsed_message)
self.assertIs(r, parsed_message)
self.assertEquals(message, parsed_message,
'\n%s != %s' % (message, parsed_message))
- def testPrintRawUtf8String(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintRawUtf8String(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_string.append(u'\u00fc\ua71f')
text = text_format.MessageToString(message, as_utf8=True)
self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n')
- parsed_message = unittest_pb2.TestAllTypes()
+ parsed_message = message_module.TestAllTypes()
text_format.Parse(text, parsed_message)
self.assertEquals(message, parsed_message,
'\n%s != %s' % (message, parsed_message))
- def testPrintFloatFormat(self):
+ def testPrintFloatFormat(self, message_module):
# Check that float_format argument is passed to sub-message formatting.
- message = unittest_pb2.NestedTestAllTypes()
+ message = message_module.NestedTestAllTypes()
# We use 1.25 as it is a round number in binary. The proto 32-bit float
# will not gain additional imprecise digits as a 64-bit Python float and
# show up in its str. 32-bit 1.2 is noisy when extended to 64-bit:
@@ -285,85 +223,24 @@ class TextFormatTest(basetest.TestCase):
self.RemoveRedundantZeros(text_message),
'payload {{ {} {} {} {} }}'.format(*formatted_fields))
- def testMessageToString(self):
- message = unittest_pb2.ForeignMessage()
+ def testMessageToString(self, message_module):
+ message = message_module.ForeignMessage()
message.c = 123
self.assertEqual('c: 123\n', str(message))
- def RemoveRedundantZeros(self, text):
- # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove
- # these zeros in order to match the golden file.
- text = text.replace('e+0','e+').replace('e+0','e+') \
- .replace('e-0','e-').replace('e-0','e-')
- # Floating point fields are printed with .0 suffix even if they are
- # actualy integer numbers.
- text = re.compile('\.0$', re.MULTILINE).sub('', text)
- return text
-
- def testParseGolden(self):
- golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt'))
- parsed_message = unittest_pb2.TestAllTypes()
- r = text_format.Parse(golden_text, parsed_message)
- self.assertIs(r, parsed_message)
-
- message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.assertEquals(message, parsed_message)
-
- def testParseGoldenExtensions(self):
- golden_text = '\n'.join(self.ReadGolden(
- 'text_format_unittest_extensions_data.txt'))
- parsed_message = unittest_pb2.TestAllExtensions()
- text_format.Parse(golden_text, parsed_message)
-
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- self.assertEquals(message, parsed_message)
-
- def testParseAllFields(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseAllFields(self, message_module):
+ message = message_module.TestAllTypes()
test_util.SetAllFields(message)
ascii_text = text_format.MessageToString(message)
- parsed_message = unittest_pb2.TestAllTypes()
- text_format.Parse(ascii_text, parsed_message)
- self.assertEqual(message, parsed_message)
- test_util.ExpectAllFieldsSet(self, message)
-
- def testParseAllExtensions(self):
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- ascii_text = text_format.MessageToString(message)
-
- parsed_message = unittest_pb2.TestAllExtensions()
+ parsed_message = message_module.TestAllTypes()
text_format.Parse(ascii_text, parsed_message)
self.assertEqual(message, parsed_message)
+ if message_module is unittest_pb2:
+ test_util.ExpectAllFieldsSet(self, message)
- def testParseMessageSet(self):
- message = unittest_pb2.TestAllTypes()
- text = ('repeated_uint64: 1\n'
- 'repeated_uint64: 2\n')
- text_format.Parse(text, message)
- self.assertEqual(1, message.repeated_uint64[0])
- self.assertEqual(2, message.repeated_uint64[1])
-
- message = unittest_mset_pb2.TestMessageSetContainer()
- text = ('message_set {\n'
- ' [protobuf_unittest.TestMessageSetExtension1] {\n'
- ' i: 23\n'
- ' }\n'
- ' [protobuf_unittest.TestMessageSetExtension2] {\n'
- ' str: \"foo\"\n'
- ' }\n'
- '}\n')
- text_format.Parse(text, message)
- ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
- ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
- self.assertEquals(23, message.message_set.Extensions[ext1].i)
- self.assertEquals('foo', message.message_set.Extensions[ext2].str)
-
- def testParseExotic(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseExotic(self, message_module):
+ message = message_module.TestAllTypes()
text = ('repeated_int64: -9223372036854775808\n'
'repeated_uint64: 18446744073709551615\n'
'repeated_double: 123.456\n'
@@ -388,8 +265,8 @@ class TextFormatTest(basetest.TestCase):
self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2])
self.assertEqual(u'\u00fc', message.repeated_string[3])
- def testParseTrailingCommas(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseTrailingCommas(self, message_module):
+ message = message_module.TestAllTypes()
text = ('repeated_int64: 100;\n'
'repeated_int64: 200;\n'
'repeated_int64: 300,\n'
@@ -403,51 +280,37 @@ class TextFormatTest(basetest.TestCase):
self.assertEqual(u'one', message.repeated_string[0])
self.assertEqual(u'two', message.repeated_string[1])
- def testParseEmptyText(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseEmptyText(self, message_module):
+ message = message_module.TestAllTypes()
text = ''
text_format.Parse(text, message)
- self.assertEquals(unittest_pb2.TestAllTypes(), message)
+ self.assertEquals(message_module.TestAllTypes(), message)
- def testParseInvalidUtf8(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseInvalidUtf8(self, message_module):
+ message = message_module.TestAllTypes()
text = 'repeated_string: "\\xc3\\xc3"'
self.assertRaises(text_format.ParseError, text_format.Parse, text, message)
- def testParseSingleWord(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseSingleWord(self, message_module):
+ message = message_module.TestAllTypes()
text = 'foo'
- self.assertRaisesWithLiteralMatch(
+ self.assertRaisesRegexp(
text_format.ParseError,
- ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named '
- '"foo".'),
+ (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"foo".'),
text_format.Parse, text, message)
- def testParseUnknownField(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseUnknownField(self, message_module):
+ message = message_module.TestAllTypes()
text = 'unknown_field: 8\n'
- self.assertRaisesWithLiteralMatch(
+ self.assertRaisesRegexp(
text_format.ParseError,
- ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named '
- '"unknown_field".'),
+ (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"unknown_field".'),
text_format.Parse, text, message)
- def testParseBadExtension(self):
- message = unittest_pb2.TestAllExtensions()
- text = '[unknown_extension]: 8\n'
- self.assertRaisesWithLiteralMatch(
- text_format.ParseError,
- '1:2 : Extension "unknown_extension" not registered.',
- text_format.Parse, text, message)
- message = unittest_pb2.TestAllTypes()
- self.assertRaisesWithLiteralMatch(
- text_format.ParseError,
- ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
- 'extensions.'),
- text_format.Parse, text, message)
-
- def testParseGroupNotClosed(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseGroupNotClosed(self, message_module):
+ message = message_module.TestAllTypes()
text = 'RepeatedGroup: <'
self.assertRaisesWithLiteralMatch(
text_format.ParseError, '1:16 : Expected ">".',
@@ -458,46 +321,46 @@ class TextFormatTest(basetest.TestCase):
text_format.ParseError, '1:16 : Expected "}".',
text_format.Parse, text, message)
- def testParseEmptyGroup(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseEmptyGroup(self, message_module):
+ message = message_module.TestAllTypes()
text = 'OptionalGroup: {}'
text_format.Parse(text, message)
self.assertTrue(message.HasField('optionalgroup'))
message.Clear()
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
text = 'OptionalGroup: <>'
text_format.Parse(text, message)
self.assertTrue(message.HasField('optionalgroup'))
- def testParseBadEnumValue(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseBadEnumValue(self, message_module):
+ message = message_module.TestAllTypes()
text = 'optional_nested_enum: BARR'
- self.assertRaisesWithLiteralMatch(
+ self.assertRaisesRegexp(
text_format.ParseError,
- ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
- 'has no value named BARR.'),
+ (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
+ r'has no value named BARR.'),
text_format.Parse, text, message)
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
text = 'optional_nested_enum: 100'
- self.assertRaisesWithLiteralMatch(
+ self.assertRaisesRegexp(
text_format.ParseError,
- ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
- 'has no value with number 100.'),
+ (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
+ r'has no value with number 100.'),
text_format.Parse, text, message)
- def testParseBadIntValue(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseBadIntValue(self, message_module):
+ message = message_module.TestAllTypes()
text = 'optional_int32: bork'
self.assertRaisesWithLiteralMatch(
text_format.ParseError,
('1:17 : Couldn\'t parse integer: bork'),
text_format.Parse, text, message)
- def testParseStringFieldUnescape(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseStringFieldUnescape(self, message_module):
+ message = message_module.TestAllTypes()
text = r'''repeated_string: "\xf\x62"
repeated_string: "\\xf\\x62"
repeated_string: "\\\xf\\\x62"
@@ -516,40 +379,205 @@ class TextFormatTest(basetest.TestCase):
message.repeated_string[4])
self.assertEqual(SLASH + 'x20', message.repeated_string[5])
- def testMergeDuplicateScalars(self):
- message = unittest_pb2.TestAllTypes()
+ def testMergeDuplicateScalars(self, message_module):
+ message = message_module.TestAllTypes()
text = ('optional_int32: 42 '
'optional_int32: 67')
r = text_format.Merge(text, message)
self.assertIs(r, message)
self.assertEqual(67, message.optional_int32)
- def testParseDuplicateScalars(self):
- message = unittest_pb2.TestAllTypes()
- text = ('optional_int32: 42 '
- 'optional_int32: 67')
- self.assertRaisesWithLiteralMatch(
- text_format.ParseError,
- ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
- 'have multiple "optional_int32" fields.'),
- text_format.Parse, text, message)
-
- def testMergeDuplicateNestedMessageScalars(self):
- message = unittest_pb2.TestAllTypes()
+ def testMergeDuplicateNestedMessageScalars(self, message_module):
+ message = message_module.TestAllTypes()
text = ('optional_nested_message { bb: 1 } '
'optional_nested_message { bb: 2 }')
r = text_format.Merge(text, message)
self.assertTrue(r is message)
self.assertEqual(2, message.optional_nested_message.bb)
- def testParseDuplicateNestedMessageScalars(self):
+ def testParseOneof(self, message_module):
+ m = message_module.TestAllTypes()
+ m.oneof_uint32 = 11
+ m2 = message_module.TestAllTypes()
+ text_format.Parse(text_format.MessageToString(m), m2)
+ self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
+
+
+# These are tests that aren't fundamentally specific to proto2, but are at
+# the moment because of differences between the proto2 and proto3 test schemas.
+# Ideally the schemas would be made more similar so these tests could pass.
+class OnlyWorksWithProto2RightNowTests(TextFormatBase):
+
+ def testParseGolden(self):
+ golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt'))
+ parsed_message = unittest_pb2.TestAllTypes()
+ r = text_format.Parse(golden_text, parsed_message)
+ self.assertIs(r, parsed_message)
+
message = unittest_pb2.TestAllTypes()
- text = ('optional_nested_message { bb: 1 } '
- 'optional_nested_message { bb: 2 }')
+ test_util.SetAllFields(message)
+ self.assertEquals(message, parsed_message)
+
+ def testPrintAllFields(self):
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(text_format.MessageToString(message)),
+ 'text_format_unittest_data_oneof_implemented.txt')
+
+ def testPrintAllFieldsPointy(self):
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(
+ text_format.MessageToString(message, pointy_brackets=True)),
+ 'text_format_unittest_data_pointy_oneof.txt')
+
+ def testPrintInIndexOrder(self):
+ message = unittest_pb2.TestFieldOrderings()
+ message.my_string = '115'
+ message.my_int = 101
+ message.my_float = 111
+ message.optional_nested_message.oo = 0
+ message.optional_nested_message.bb = 1
+ self.CompareToGoldenText(
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message, use_index_order=True)),
+ 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n'
+ 'optional_nested_message {\n oo: 0\n bb: 1\n}\n')
+ self.CompareToGoldenText(
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message)),
+ 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n'
+ 'optional_nested_message {\n bb: 1\n oo: 0\n}\n')
+
+ def testMergeLinesGolden(self):
+ opened = self.ReadGolden('text_format_unittest_data.txt')
+ parsed_message = unittest_pb2.TestAllTypes()
+ r = text_format.MergeLines(opened, parsed_message)
+ self.assertIs(r, parsed_message)
+
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.assertEqual(message, parsed_message)
+
+ def testParseLinesGolden(self):
+ opened = self.ReadGolden('text_format_unittest_data.txt')
+ parsed_message = unittest_pb2.TestAllTypes()
+ r = text_format.ParseLines(opened, parsed_message)
+ self.assertIs(r, parsed_message)
+
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.assertEquals(message, parsed_message)
+
+
+# Tests of proto2-only features (MessageSet and extensions).
+class Proto2Tests(TextFormatBase):
+
+ def testPrintMessageSet(self):
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ self.CompareToGoldenText(
+ text_format.MessageToString(message),
+ 'message_set {\n'
+ ' [protobuf_unittest.TestMessageSetExtension1] {\n'
+ ' i: 23\n'
+ ' }\n'
+ ' [protobuf_unittest.TestMessageSetExtension2] {\n'
+ ' str: \"foo\"\n'
+ ' }\n'
+ '}\n')
+
+ def testPrintMessageSetAsOneLine(self):
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ self.CompareToGoldenText(
+ text_format.MessageToString(message, as_one_line=True),
+ 'message_set {'
+ ' [protobuf_unittest.TestMessageSetExtension1] {'
+ ' i: 23'
+ ' }'
+ ' [protobuf_unittest.TestMessageSetExtension2] {'
+ ' str: \"foo\"'
+ ' }'
+ ' }')
+
+ def testParseMessageSet(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('repeated_uint64: 1\n'
+ 'repeated_uint64: 2\n')
+ text_format.Parse(text, message)
+ self.assertEqual(1, message.repeated_uint64[0])
+ self.assertEqual(2, message.repeated_uint64[1])
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ text = ('message_set {\n'
+ ' [protobuf_unittest.TestMessageSetExtension1] {\n'
+ ' i: 23\n'
+ ' }\n'
+ ' [protobuf_unittest.TestMessageSetExtension2] {\n'
+ ' str: \"foo\"\n'
+ ' }\n'
+ '}\n')
+ text_format.Parse(text, message)
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ self.assertEquals(23, message.message_set.Extensions[ext1].i)
+ self.assertEquals('foo', message.message_set.Extensions[ext2].str)
+
+ def testPrintAllExtensions(self):
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(text_format.MessageToString(message)),
+ 'text_format_unittest_extensions_data.txt')
+
+ def testPrintAllExtensionsPointy(self):
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message, pointy_brackets=True)),
+ 'text_format_unittest_extensions_data_pointy.txt')
+
+ def testParseGoldenExtensions(self):
+ golden_text = '\n'.join(self.ReadGolden(
+ 'text_format_unittest_extensions_data.txt'))
+ parsed_message = unittest_pb2.TestAllExtensions()
+ text_format.Parse(golden_text, parsed_message)
+
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ self.assertEquals(message, parsed_message)
+
+ def testParseAllExtensions(self):
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ ascii_text = text_format.MessageToString(message)
+
+ parsed_message = unittest_pb2.TestAllExtensions()
+ text_format.Parse(ascii_text, parsed_message)
+ self.assertEqual(message, parsed_message)
+
+ def testParseBadExtension(self):
+ message = unittest_pb2.TestAllExtensions()
+ text = '[unknown_extension]: 8\n'
self.assertRaisesWithLiteralMatch(
text_format.ParseError,
- ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" '
- 'should not have multiple "bb" fields.'),
+ '1:2 : Extension "unknown_extension" not registered.',
+ text_format.Parse, text, message)
+ message = unittest_pb2.TestAllTypes()
+ self.assertRaisesWithLiteralMatch(
+ text_format.ParseError,
+ ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
+ 'extensions.'),
text_format.Parse, text, message)
def testMergeDuplicateExtensionScalars(self):
@@ -572,32 +600,25 @@ class TextFormatTest(basetest.TestCase):
'"protobuf_unittest.optional_int32_extension" extensions.'),
text_format.Parse, text, message)
- def testParseLinesGolden(self):
- opened = self.ReadGolden('text_format_unittest_data.txt')
- parsed_message = unittest_pb2.TestAllTypes()
- r = text_format.ParseLines(opened, parsed_message)
- self.assertIs(r, parsed_message)
-
+ def testParseDuplicateNestedMessageScalars(self):
message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.assertEquals(message, parsed_message)
-
- def testMergeLinesGolden(self):
- opened = self.ReadGolden('text_format_unittest_data.txt')
- parsed_message = unittest_pb2.TestAllTypes()
- r = text_format.MergeLines(opened, parsed_message)
- self.assertIs(r, parsed_message)
+ text = ('optional_nested_message { bb: 1 } '
+ 'optional_nested_message { bb: 2 }')
+ self.assertRaisesWithLiteralMatch(
+ text_format.ParseError,
+ ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" '
+ 'should not have multiple "bb" fields.'),
+ text_format.Parse, text, message)
+ def testParseDuplicateScalars(self):
message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.assertEqual(message, parsed_message)
-
- def testParseOneof(self):
- m = unittest_pb2.TestAllTypes()
- m.oneof_uint32 = 11
- m2 = unittest_pb2.TestAllTypes()
- text_format.Parse(text_format.MessageToString(m), m2)
- self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
+ text = ('optional_int32: 42 '
+ 'optional_int32: 67')
+ self.assertRaisesWithLiteralMatch(
+ text_format.ParseError,
+ ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
+ 'have multiple "optional_int32" fields.'),
+ text_format.Parse, text, message)
class TokenizerTest(basetest.TestCase):
diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py
index 118725da..76c056c4 100755
--- a/python/google/protobuf/internal/type_checkers.py
+++ b/python/google/protobuf/internal/type_checkers.py
@@ -59,6 +59,8 @@ from google.protobuf import descriptor
_FieldDescriptor = descriptor.FieldDescriptor
+def SupportsOpenEnums(field_descriptor):
+ return field_descriptor.containing_type.syntax == "proto3"
def GetTypeChecker(field):
"""Returns a type checker for a message field of the specified types.
@@ -74,7 +76,11 @@ def GetTypeChecker(field):
field.type == _FieldDescriptor.TYPE_STRING):
return UnicodeValueChecker()
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
- return EnumValueChecker(field.enum_type)
+ if SupportsOpenEnums(field):
+ # When open enums are supported, any int32 can be assigned.
+ return _VALUE_CHECKERS[_FieldDescriptor.CPPTYPE_INT32]
+ else:
+ return EnumValueChecker(field.enum_type)
return _VALUE_CHECKERS[field.cpp_type]
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py
index a4dc1f7c..59f9ae4c 100755
--- a/python/google/protobuf/internal/unknown_fields_test.py
+++ b/python/google/protobuf/internal/unknown_fields_test.py
@@ -38,6 +38,7 @@ __author__ = 'bohdank@google.com (Bohdan Koval)'
from google.apputils import basetest
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
+from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
from google.protobuf.internal import missing_enum_values_pb2
@@ -45,10 +46,81 @@ from google.protobuf.internal import test_util
from google.protobuf.internal import type_checkers
+class UnknownFieldsTest(basetest.TestCase):
+
+ def setUp(self):
+ self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ self.all_fields = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(self.all_fields)
+ self.all_fields_data = self.all_fields.SerializeToString()
+ self.empty_message = unittest_pb2.TestEmptyMessage()
+ self.empty_message.ParseFromString(self.all_fields_data)
+
+ def testSerialize(self):
+ data = self.empty_message.SerializeToString()
+
+ # Don't use assertEqual because we don't want to dump raw binary data to
+ # stdout.
+ self.assertTrue(data == self.all_fields_data)
+
+ def testSerializeProto3(self):
+ # Verify that proto3 doesn't preserve unknown fields.
+ message = unittest_proto3_arena_pb2.TestEmptyMessage()
+ message.ParseFromString(self.all_fields_data)
+ self.assertEqual(0, len(message.SerializeToString()))
+
+ def testByteSize(self):
+ self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
+
+ def testListFields(self):
+ # Make sure ListFields doesn't return unknown fields.
+ self.assertEqual(0, len(self.empty_message.ListFields()))
+
+ def testSerializeMessageSetWireFormatUnknownExtension(self):
+ # Create a message using the message set wire format with an unknown
+ # message.
+ raw = unittest_mset_pb2.RawMessageSet()
+
+ # Add an unknown extension.
+ item = raw.item.add()
+ item.type_id = 1545009
+ message1 = unittest_mset_pb2.TestMessageSetExtension1()
+ message1.i = 12345
+ item.message = message1.SerializeToString()
+
+ serialized = raw.SerializeToString()
+
+ # Parse message using the message set wire format.
+ proto = unittest_mset_pb2.TestMessageSet()
+ proto.MergeFromString(serialized)
+
+ # Verify that the unknown extension is serialized unchanged
+ reserialized = proto.SerializeToString()
+ new_raw = unittest_mset_pb2.RawMessageSet()
+ new_raw.MergeFromString(reserialized)
+ self.assertEqual(raw, new_raw)
+
+ # C++ implementation for proto2 does not currently take into account unknown
+ # fields when checking equality.
+ #
+ # TODO(haberman): fix this.
+ @basetest.unittest.skipIf(
+ api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
+ 'C++ implementation does not expose unknown fields to Python')
+ def testEquals(self):
+ message = unittest_pb2.TestEmptyMessage()
+ message.ParseFromString(self.all_fields_data)
+ self.assertEqual(self.empty_message, message)
+
+ self.all_fields.ClearField('optional_string')
+ message.ParseFromString(self.all_fields.SerializeToString())
+ self.assertNotEqual(self.empty_message, message)
+
+
@basetest.unittest.skipIf(
api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
'C++ implementation does not expose unknown fields to Python')
-class UnknownFieldsTest(basetest.TestCase):
+class UnknownFieldsAccessorsTest(basetest.TestCase):
def setUp(self):
self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
@@ -98,13 +170,6 @@ class UnknownFieldsTest(basetest.TestCase):
value = self.GetField('optionalgroup')
self.assertEqual(self.all_fields.optionalgroup, value)
- def testSerialize(self):
- data = self.empty_message.SerializeToString()
-
- # Don't use assertEqual because we don't want to dump raw binary data to
- # stdout.
- self.assertTrue(data == self.all_fields_data)
-
def testCopyFrom(self):
message = unittest_pb2.TestEmptyMessage()
message.CopyFrom(self.empty_message)
@@ -132,51 +197,12 @@ class UnknownFieldsTest(basetest.TestCase):
self.empty_message.Clear()
self.assertEqual(0, len(self.empty_message._unknown_fields))
- def testByteSize(self):
- self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
-
def testUnknownExtensions(self):
message = unittest_pb2.TestEmptyMessageWithExtensions()
message.ParseFromString(self.all_fields_data)
self.assertEqual(self.empty_message._unknown_fields,
message._unknown_fields)
- def testListFields(self):
- # Make sure ListFields doesn't return unknown fields.
- self.assertEqual(0, len(self.empty_message.ListFields()))
-
- def testSerializeMessageSetWireFormatUnknownExtension(self):
- # Create a message using the message set wire format with an unknown
- # message.
- raw = unittest_mset_pb2.RawMessageSet()
-
- # Add an unknown extension.
- item = raw.item.add()
- item.type_id = 1545009
- message1 = unittest_mset_pb2.TestMessageSetExtension1()
- message1.i = 12345
- item.message = message1.SerializeToString()
-
- serialized = raw.SerializeToString()
-
- # Parse message using the message set wire format.
- proto = unittest_mset_pb2.TestMessageSet()
- proto.MergeFromString(serialized)
-
- # Verify that the unknown extension is serialized unchanged
- reserialized = proto.SerializeToString()
- new_raw = unittest_mset_pb2.RawMessageSet()
- new_raw.MergeFromString(reserialized)
- self.assertEqual(raw, new_raw)
-
- def testEquals(self):
- message = unittest_pb2.TestEmptyMessage()
- message.ParseFromString(self.all_fields_data)
- self.assertEqual(self.empty_message, message)
-
- self.all_fields.ClearField('optional_string')
- message.ParseFromString(self.all_fields.SerializeToString())
- self.assertNotEqual(self.empty_message, message)
@basetest.unittest.skipIf(