aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/internal
diff options
context:
space:
mode:
authorGravatar Dan O'Reilly <oreilldf@gmail.com>2015-08-12 23:57:46 -0400
committerGravatar Dan O'Reilly <oreilldf@gmail.com>2015-08-12 23:57:46 -0400
commite47cdd5a559f488ba52756927ce68f4cf93874fa (patch)
tree8ce2723e822808baf58e96f569c86035717ea351 /python/google/protobuf/internal
parentdaeaa6a28b81195f24d89222e649d79c9555af8b (diff)
parent38a56ee4b19d72c2e9d81a08b018704d1addf561 (diff)
Merge remote-tracking branch 'upstream/master' into py2_py3_straddle
Conflicts: python/google/protobuf/descriptor_pool.py python/google/protobuf/internal/api_implementation_default_test.py python/google/protobuf/internal/cpp_message.py python/google/protobuf/internal/descriptor_database_test.py python/google/protobuf/internal/descriptor_pool_test.py python/google/protobuf/internal/descriptor_python_test.py python/google/protobuf/internal/descriptor_test.py python/google/protobuf/internal/generator_test.py python/google/protobuf/internal/message_factory_python_test.py python/google/protobuf/internal/message_factory_test.py python/google/protobuf/internal/message_test.py python/google/protobuf/internal/proto_builder_test.py python/google/protobuf/internal/python_message.py python/google/protobuf/internal/reflection_test.py python/google/protobuf/internal/service_reflection_test.py python/google/protobuf/internal/symbol_database_test.py python/google/protobuf/internal/text_encoding_test.py python/google/protobuf/internal/text_format_test.py python/google/protobuf/internal/unknown_fields_test.py python/google/protobuf/internal/wire_format_test.py python/google/protobuf/pyext/descriptor_cpp2_test.py python/google/protobuf/pyext/message_factory_cpp2_test.py python/google/protobuf/pyext/reflection_cpp2_generated_test.py python/setup.py ruby/lib/google/protobuf/message_exts.rb
Diffstat (limited to 'python/google/protobuf/internal')
-rwxr-xr-xpython/google/protobuf/internal/_parameterized.py436
-rw-r--r--python/google/protobuf/internal/api_implementation.cc14
-rwxr-xr-xpython/google/protobuf/internal/api_implementation.py42
-rw-r--r--python/google/protobuf/internal/api_implementation_default_test.py63
-rwxr-xr-xpython/google/protobuf/internal/containers.py332
-rwxr-xr-xpython/google/protobuf/internal/cpp_message.py667
-rwxr-xr-xpython/google/protobuf/internal/decoder.py47
-rw-r--r--python/google/protobuf/internal/descriptor_database_test.py3
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test.py10
-rw-r--r--python/google/protobuf/internal/descriptor_python_test.py54
-rwxr-xr-xpython/google/protobuf/internal/descriptor_test.py177
-rwxr-xr-xpython/google/protobuf/internal/encoder.py55
-rwxr-xr-xpython/google/protobuf/internal/generator_test.py7
-rw-r--r--python/google/protobuf/internal/import_test_package/BUILD27
-rw-r--r--python/google/protobuf/internal/message_factory_python_test.py54
-rw-r--r--python/google/protobuf/internal/message_factory_test.py3
-rw-r--r--python/google/protobuf/internal/message_python_test.py54
-rwxr-xr-xpython/google/protobuf/internal/message_test.py1259
-rw-r--r--python/google/protobuf/internal/proto_builder_test.py23
-rwxr-xr-xpython/google/protobuf/internal/python_message.py235
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py5
-rwxr-xr-xpython/google/protobuf/internal/service_reflection_test.py3
-rw-r--r--python/google/protobuf/internal/symbol_database_test.py3
-rwxr-xr-xpython/google/protobuf/internal/test_util.py201
-rwxr-xr-xpython/google/protobuf/internal/text_encoding_test.py3
-rwxr-xr-xpython/google/protobuf/internal/text_format_test.py692
-rwxr-xr-xpython/google/protobuf/internal/type_checkers.py17
-rwxr-xr-xpython/google/protobuf/internal/unknown_fields_test.py122
-rwxr-xr-xpython/google/protobuf/internal/wire_format_test.py3
29 files changed, 2948 insertions, 1663 deletions
diff --git a/python/google/protobuf/internal/_parameterized.py b/python/google/protobuf/internal/_parameterized.py
new file mode 100755
index 00000000..400b2216
--- /dev/null
+++ b/python/google/protobuf/internal/_parameterized.py
@@ -0,0 +1,436 @@
+#! /usr/bin/env python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Adds support for parameterized tests to Python's unittest TestCase class.
+
+A parameterized test is a method in a test case that is invoked with different
+argument tuples.
+
+A simple example:
+
+ class AdditionExample(parameterized.ParameterizedTestCase):
+ @parameterized.Parameters(
+ (1, 2, 3),
+ (4, 5, 9),
+ (1, 1, 3))
+ def testAddition(self, op1, op2, result):
+ self.assertEquals(result, op1 + op2)
+
+
+Each invocation is a separate test case and properly isolated just
+like a normal test method, with its own setUp/tearDown cycle. In the
+example above, there are three separate testcases, one of which will
+fail due to an assertion error (1 + 1 != 3).
+
+Parameters for invididual test cases can be tuples (with positional parameters)
+or dictionaries (with named parameters):
+
+ class AdditionExample(parameterized.ParameterizedTestCase):
+ @parameterized.Parameters(
+ {'op1': 1, 'op2': 2, 'result': 3},
+ {'op1': 4, 'op2': 5, 'result': 9},
+ )
+ def testAddition(self, op1, op2, result):
+ self.assertEquals(result, op1 + op2)
+
+If a parameterized test fails, the error message will show the
+original test name (which is modified internally) and the arguments
+for the specific invocation, which are part of the string returned by
+the shortDescription() method on test cases.
+
+The id method of the test, used internally by the unittest framework,
+is also modified to show the arguments. To make sure that test names
+stay the same across several invocations, object representations like
+
+ >>> class Foo(object):
+ ... pass
+ >>> repr(Foo())
+ '<__main__.Foo object at 0x23d8610>'
+
+are turned into '<__main__.Foo>'. For even more descriptive names,
+especially in test logs, you can use the NamedParameters decorator. In
+this case, only tuples are supported, and the first parameters has to
+be a string (or an object that returns an apt name when converted via
+str()):
+
+ class NamedExample(parameterized.ParameterizedTestCase):
+ @parameterized.NamedParameters(
+ ('Normal', 'aa', 'aaa', True),
+ ('EmptyPrefix', '', 'abc', True),
+ ('BothEmpty', '', '', True))
+ def testStartsWith(self, prefix, string, result):
+ self.assertEquals(result, strings.startswith(prefix))
+
+Named tests also have the benefit that they can be run individually
+from the command line:
+
+ $ testmodule.py NamedExample.testStartsWithNormal
+ .
+ --------------------------------------------------------------------
+ Ran 1 test in 0.000s
+
+ OK
+
+Parameterized Classes
+=====================
+If invocation arguments are shared across test methods in a single
+ParameterizedTestCase class, instead of decorating all test methods
+individually, the class itself can be decorated:
+
+ @parameterized.Parameters(
+ (1, 2, 3)
+ (4, 5, 9))
+ class ArithmeticTest(parameterized.ParameterizedTestCase):
+ def testAdd(self, arg1, arg2, result):
+ self.assertEqual(arg1 + arg2, result)
+
+ def testSubtract(self, arg2, arg2, result):
+ self.assertEqual(result - arg1, arg2)
+
+Inputs from Iterables
+=====================
+If parameters should be shared across several test cases, or are dynamically
+created from other sources, a single non-tuple iterable can be passed into
+the decorator. This iterable will be used to obtain the test cases:
+
+ class AdditionExample(parameterized.ParameterizedTestCase):
+ @parameterized.Parameters(
+ c.op1, c.op2, c.result for c in testcases
+ )
+ def testAddition(self, op1, op2, result):
+ self.assertEquals(result, op1 + op2)
+
+
+Single-Argument Test Methods
+============================
+If a test method takes only one argument, the single argument does not need to
+be wrapped into a tuple:
+
+ class NegativeNumberExample(parameterized.ParameterizedTestCase):
+ @parameterized.Parameters(
+ -1, -3, -4, -5
+ )
+ def testIsNegative(self, arg):
+ self.assertTrue(IsNegative(arg))
+"""
+
+__author__ = 'tmarek@google.com (Torsten Marek)'
+
+import collections
+import functools
+import re
+import types
+import unittest
+import uuid
+
+ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>')
+_SEPARATOR = uuid.uuid1().hex
+_FIRST_ARG = object()
+_ARGUMENT_REPR = object()
+
+
+def _CleanRepr(obj):
+ return ADDR_RE.sub(r'<\1>', repr(obj))
+
+
+# Helper function formerly from the unittest module, removed from it in
+# Python 2.7.
+def _StrClass(cls):
+ return '%s.%s' % (cls.__module__, cls.__name__)
+
+
+def _NonStringIterable(obj):
+ return (isinstance(obj, collections.Iterable) and not
+ isinstance(obj, basestring))
+
+
+def _FormatParameterList(testcase_params):
+ if isinstance(testcase_params, collections.Mapping):
+ return ', '.join('%s=%s' % (argname, _CleanRepr(value))
+ for argname, value in testcase_params.iteritems())
+ elif _NonStringIterable(testcase_params):
+ return ', '.join(map(_CleanRepr, testcase_params))
+ else:
+ return _FormatParameterList((testcase_params,))
+
+
+class _ParameterizedTestIter(object):
+ """Callable and iterable class for producing new test cases."""
+
+ def __init__(self, test_method, testcases, naming_type):
+ """Returns concrete test functions for a test and a list of parameters.
+
+ The naming_type is used to determine the name of the concrete
+ functions as reported by the unittest framework. If naming_type is
+ _FIRST_ARG, the testcases must be tuples, and the first element must
+ have a string representation that is a valid Python identifier.
+
+ Args:
+ test_method: The decorated test method.
+ testcases: (list of tuple/dict) A list of parameter
+ tuples/dicts for individual test invocations.
+ naming_type: The test naming type, either _NAMED or _ARGUMENT_REPR.
+ """
+ self._test_method = test_method
+ self.testcases = testcases
+ self._naming_type = naming_type
+
+ def __call__(self, *args, **kwargs):
+ raise RuntimeError('You appear to be running a parameterized test case '
+ 'without having inherited from parameterized.'
+ 'ParameterizedTestCase. This is bad because none of '
+ 'your test cases are actually being run.')
+
+ def __iter__(self):
+ test_method = self._test_method
+ naming_type = self._naming_type
+
+ def MakeBoundParamTest(testcase_params):
+ @functools.wraps(test_method)
+ def BoundParamTest(self):
+ if isinstance(testcase_params, collections.Mapping):
+ test_method(self, **testcase_params)
+ elif _NonStringIterable(testcase_params):
+ test_method(self, *testcase_params)
+ else:
+ test_method(self, testcase_params)
+
+ if naming_type is _FIRST_ARG:
+ # Signal the metaclass that the name of the test function is unique
+ # and descriptive.
+ BoundParamTest.__x_use_name__ = True
+ BoundParamTest.__name__ += str(testcase_params[0])
+ testcase_params = testcase_params[1:]
+ elif naming_type is _ARGUMENT_REPR:
+ # __x_extra_id__ is used to pass naming information to the __new__
+ # method of TestGeneratorMetaclass.
+ # The metaclass will make sure to create a unique, but nondescriptive
+ # name for this test.
+ BoundParamTest.__x_extra_id__ = '(%s)' % (
+ _FormatParameterList(testcase_params),)
+ else:
+ raise RuntimeError('%s is not a valid naming type.' % (naming_type,))
+
+ BoundParamTest.__doc__ = '%s(%s)' % (
+ BoundParamTest.__name__, _FormatParameterList(testcase_params))
+ if test_method.__doc__:
+ BoundParamTest.__doc__ += '\n%s' % (test_method.__doc__,)
+ return BoundParamTest
+ return (MakeBoundParamTest(c) for c in self.testcases)
+
+
+def _IsSingletonList(testcases):
+ """True iff testcases contains only a single non-tuple element."""
+ return len(testcases) == 1 and not isinstance(testcases[0], tuple)
+
+
+def _ModifyClass(class_object, testcases, naming_type):
+ assert not getattr(class_object, '_id_suffix', None), (
+ 'Cannot add parameters to %s,'
+ ' which already has parameterized methods.' % (class_object,))
+ class_object._id_suffix = id_suffix = {}
+ for name, obj in class_object.__dict__.items():
+ if (name.startswith(unittest.TestLoader.testMethodPrefix)
+ and isinstance(obj, types.FunctionType)):
+ delattr(class_object, name)
+ methods = {}
+ _UpdateClassDictForParamTestCase(
+ methods, id_suffix, name,
+ _ParameterizedTestIter(obj, testcases, naming_type))
+ for name, meth in methods.iteritems():
+ setattr(class_object, name, meth)
+
+
+def _ParameterDecorator(naming_type, testcases):
+ """Implementation of the parameterization decorators.
+
+ Args:
+ naming_type: The naming type.
+ testcases: Testcase parameters.
+
+ Returns:
+ A function for modifying the decorated object.
+ """
+ def _Apply(obj):
+ if isinstance(obj, type):
+ _ModifyClass(
+ obj,
+ list(testcases) if not isinstance(testcases, collections.Sequence)
+ else testcases,
+ naming_type)
+ return obj
+ else:
+ return _ParameterizedTestIter(obj, testcases, naming_type)
+
+ if _IsSingletonList(testcases):
+ assert _NonStringIterable(testcases[0]), (
+ 'Single parameter argument must be a non-string iterable')
+ testcases = testcases[0]
+
+ return _Apply
+
+
+def Parameters(*testcases):
+ """A decorator for creating parameterized tests.
+
+ See the module docstring for a usage example.
+ Args:
+ *testcases: Parameters for the decorated method, either a single
+ iterable, or a list of tuples/dicts/objects (for tests
+ with only one argument).
+
+ Returns:
+ A test generator to be handled by TestGeneratorMetaclass.
+ """
+ return _ParameterDecorator(_ARGUMENT_REPR, testcases)
+
+
+def NamedParameters(*testcases):
+ """A decorator for creating parameterized tests.
+
+ See the module docstring for a usage example. The first element of
+ each parameter tuple should be a string and will be appended to the
+ name of the test method.
+
+ Args:
+ *testcases: Parameters for the decorated method, either a single
+ iterable, or a list of tuples.
+
+ Returns:
+ A test generator to be handled by TestGeneratorMetaclass.
+ """
+ return _ParameterDecorator(_FIRST_ARG, testcases)
+
+
+class TestGeneratorMetaclass(type):
+ """Metaclass for test cases with test generators.
+
+ A test generator is an iterable in a testcase that produces callables. These
+ callables must be single-argument methods. These methods are injected into
+ the class namespace and the original iterable is removed. If the name of the
+ iterable conforms to the test pattern, the injected methods will be picked
+ up as tests by the unittest framework.
+
+ In general, it is supposed to be used in conjuction with the
+ Parameters decorator.
+ """
+
+ def __new__(mcs, class_name, bases, dct):
+ dct['_id_suffix'] = id_suffix = {}
+ for name, obj in dct.items():
+ if (name.startswith(unittest.TestLoader.testMethodPrefix) and
+ _NonStringIterable(obj)):
+ iterator = iter(obj)
+ dct.pop(name)
+ _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator)
+
+ return type.__new__(mcs, class_name, bases, dct)
+
+
+def _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator):
+ """Adds individual test cases to a dictionary.
+
+ Args:
+ dct: The target dictionary.
+ id_suffix: The dictionary for mapping names to test IDs.
+ name: The original name of the test case.
+ iterator: The iterator generating the individual test cases.
+ """
+ for idx, func in enumerate(iterator):
+ assert callable(func), 'Test generators must yield callables, got %r' % (
+ func,)
+ if getattr(func, '__x_use_name__', False):
+ new_name = func.__name__
+ else:
+ new_name = '%s%s%d' % (name, _SEPARATOR, idx)
+ assert new_name not in dct, (
+ 'Name of parameterized test case "%s" not unique' % (new_name,))
+ dct[new_name] = func
+ id_suffix[new_name] = getattr(func, '__x_extra_id__', '')
+
+
+class ParameterizedTestCase(unittest.TestCase):
+ """Base class for test cases using the Parameters decorator."""
+ __metaclass__ = TestGeneratorMetaclass
+
+ def _OriginalName(self):
+ return self._testMethodName.split(_SEPARATOR)[0]
+
+ def __str__(self):
+ return '%s (%s)' % (self._OriginalName(), _StrClass(self.__class__))
+
+ def id(self): # pylint: disable=invalid-name
+ """Returns the descriptive ID of the test.
+
+ This is used internally by the unittesting framework to get a name
+ for the test to be used in reports.
+
+ Returns:
+ The test id.
+ """
+ return '%s.%s%s' % (_StrClass(self.__class__),
+ self._OriginalName(),
+ self._id_suffix.get(self._testMethodName, ''))
+
+
+def CoopParameterizedTestCase(other_base_class):
+ """Returns a new base class with a cooperative metaclass base.
+
+ This enables the ParameterizedTestCase to be used in combination
+ with other base classes that have custom metaclasses, such as
+ mox.MoxTestBase.
+
+ Only works with metaclasses that do not override type.__new__.
+
+ Example:
+
+ import google3
+ import mox
+
+ from google3.testing.pybase import parameterized
+
+ class ExampleTest(parameterized.CoopParameterizedTestCase(mox.MoxTestBase)):
+ ...
+
+ Args:
+ other_base_class: (class) A test case base class.
+
+ Returns:
+ A new class object.
+ """
+ metaclass = type(
+ 'CoopMetaclass',
+ (other_base_class.__metaclass__,
+ TestGeneratorMetaclass), {})
+ return metaclass(
+ 'CoopParameterizedTestCase',
+ (other_base_class, ParameterizedTestCase), {})
diff --git a/python/google/protobuf/internal/api_implementation.cc b/python/google/protobuf/internal/api_implementation.cc
index 83db40b1..6db12e8d 100644
--- a/python/google/protobuf/internal/api_implementation.cc
+++ b/python/google/protobuf/internal/api_implementation.cc
@@ -50,10 +50,7 @@ namespace python {
// and
// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2
#ifdef PYTHON_PROTO2_CPP_IMPL_V1
-#if PY_MAJOR_VERSION >= 3
-#error "PYTHON_PROTO2_CPP_IMPL_V1 is not supported under Python 3."
-#endif
-static int kImplVersion = 1;
+#error "PYTHON_PROTO2_CPP_IMPL_V1 is no longer supported."
#else
#ifdef PYTHON_PROTO2_CPP_IMPL_V2
static int kImplVersion = 2;
@@ -62,14 +59,7 @@ static int kImplVersion = 2;
static int kImplVersion = 0;
#else
-// The defaults are set here. Python 3 uses the fast C++ APIv2 by default.
-// Python 2 still uses the Python version by default until some compatibility
-// issues can be worked around.
-#if PY_MAJOR_VERSION >= 3
-static int kImplVersion = 2;
-#else
-static int kImplVersion = 0;
-#endif
+static int kImplVersion = -1; // -1 means "Unspecified by compiler flags".
#endif // PYTHON_PROTO2_PYTHON_IMPL
#endif // PYTHON_PROTO2_CPP_IMPL_V2
diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py
index f7926c16..ffcf7511 100755
--- a/python/google/protobuf/internal/api_implementation.py
+++ b/python/google/protobuf/internal/api_implementation.py
@@ -40,14 +40,33 @@ try:
# The compile-time constants in the _api_implementation module can be used to
# switch to a certain implementation of the Python API at build time.
_api_version = _api_implementation.api_version
- del _api_implementation
+ _proto_extension_modules_exist_in_build = True
except ImportError:
- _api_version = 0
+ _api_version = -1 # Unspecified by compiler flags.
+ _proto_extension_modules_exist_in_build = False
+
+if _api_version == 1:
+ raise ValueError('api_version=1 is no longer supported.')
+if _api_version < 0: # Still unspecified?
+ try:
+ # The presence of this module in a build allows the proto implementation to
+ # be upgraded merely via build deps rather than a compiler flag or the
+ # runtime environment variable.
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf import _use_fast_cpp_protos
+ # Work around a known issue in the classic bootstrap .par import hook.
+ if not _use_fast_cpp_protos:
+ raise ImportError('_use_fast_cpp_protos import succeeded but was None')
+ del _use_fast_cpp_protos
+ _api_version = 2
+ except ImportError:
+ if _proto_extension_modules_exist_in_build:
+ if sys.version_info[0] >= 3: # Python 3 defaults to C++ impl v2.
+ _api_version = 2
+ # TODO(b/17427486): Make Python 2 default to C++ impl v2.
_default_implementation_type = (
- 'python' if _api_version == 0 else 'cpp')
-_default_version_str = (
- '1' if _api_version <= 1 else '2')
+ 'python' if _api_version <= 0 else 'cpp')
# This environment variable can be used to switch to a certain implementation
# of the Python API, overriding the compile-time constants in the
@@ -61,16 +80,15 @@ if _implementation_type != 'python':
# This environment variable can be used to switch between the two
# 'cpp' implementations, overriding the compile-time constants in the
-# _api_implementation module. Right now only 1 and 2 are valid values. Any other
-# value will be ignored.
+# _api_implementation module. Right now only '2' is supported. Any other
+# value will cause an error to be raised.
_implementation_version_str = os.getenv(
- 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION',
- _default_version_str)
+ 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', '2')
-if _implementation_version_str not in ('1', '2'):
+if _implementation_version_str != '2':
raise ValueError(
- "unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: '" +
- _implementation_version_str + "' (supported versions: 1, 2)"
+ 'unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: "' +
+ _implementation_version_str + '" (supported versions: 2)'
)
_implementation_version = int(_implementation_version_str)
diff --git a/python/google/protobuf/internal/api_implementation_default_test.py b/python/google/protobuf/internal/api_implementation_default_test.py
deleted file mode 100644
index cb29e443..00000000
--- a/python/google/protobuf/internal/api_implementation_default_test.py
+++ /dev/null
@@ -1,63 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Test that the api_implementation defaults are what we expect."""
-
-import os
-import sys
-import unittest
-# Clear environment implementation settings before the google3 imports.
-os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', None)
-os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', None)
-
-# pylint: disable=g-import-not-at-top
-from google.protobuf.internal import api_implementation
-
-
-class ApiImplementationDefaultTest(unittest.TestCase):
-
- if sys.version_info.major <= 2:
-
- def testThatPythonIsTheDefault(self):
- """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail."""
- self.assertEqual('python', api_implementation.Type())
-
- else:
-
- def testThatCppApiV2IsTheDefault(self):
- """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail."""
- self.assertEqual('cpp', api_implementation.Type())
- self.assertEqual(2, api_implementation.Version())
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index 20bfa857..72c2fa01 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -41,6 +41,145 @@ are:
__author__ = 'petar@google.com (Petar Petrov)'
+import sys
+
+if sys.version_info[0] < 3:
+ # We would use collections.MutableMapping all the time, but in Python 2 it
+ # doesn't define __slots__. This causes two significant problems:
+ #
+ # 1. we can't disallow arbitrary attribute assignment, even if our derived
+ # classes *do* define __slots__.
+ #
+ # 2. we can't safely derive a C type from it without __slots__ defined (the
+ # interpreter expects to find a dict at tp_dictoffset, which we can't
+ # robustly provide. And we don't want an instance dict anyway.
+ #
+ # So this is the Python 2.7 definition of Mapping/MutableMapping functions
+ # verbatim, except that:
+ # 1. We declare __slots__.
+ # 2. We don't declare this as a virtual base class. The classes defined
+ # in collections are the interesting base classes, not us.
+ #
+ # Note: deriving from object is critical. It is the only thing that makes
+ # this a true type, allowing us to derive from it in C++ cleanly and making
+ # __slots__ properly disallow arbitrary element assignment.
+ from collections import Mapping as _Mapping
+
+ class Mapping(object):
+ __slots__ = ()
+
+ def get(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError:
+ return default
+
+ def __contains__(self, key):
+ try:
+ self[key]
+ except KeyError:
+ return False
+ else:
+ return True
+
+ def iterkeys(self):
+ return iter(self)
+
+ def itervalues(self):
+ for key in self:
+ yield self[key]
+
+ def iteritems(self):
+ for key in self:
+ yield (key, self[key])
+
+ def keys(self):
+ return list(self)
+
+ def items(self):
+ return [(key, self[key]) for key in self]
+
+ def values(self):
+ return [self[key] for key in self]
+
+ # Mappings are not hashable by default, but subclasses can change this
+ __hash__ = None
+
+ def __eq__(self, other):
+ if not isinstance(other, _Mapping):
+ return NotImplemented
+ return dict(self.items()) == dict(other.items())
+
+ def __ne__(self, other):
+ return not (self == other)
+
+ class MutableMapping(Mapping):
+ __slots__ = ()
+
+ __marker = object()
+
+ def pop(self, key, default=__marker):
+ try:
+ value = self[key]
+ except KeyError:
+ if default is self.__marker:
+ raise
+ return default
+ else:
+ del self[key]
+ return value
+
+ def popitem(self):
+ try:
+ key = next(iter(self))
+ except StopIteration:
+ raise KeyError
+ value = self[key]
+ del self[key]
+ return key, value
+
+ def clear(self):
+ try:
+ while True:
+ self.popitem()
+ except KeyError:
+ pass
+
+ def update(*args, **kwds):
+ if len(args) > 2:
+ raise TypeError("update() takes at most 2 positional "
+ "arguments ({} given)".format(len(args)))
+ elif not args:
+ raise TypeError("update() takes at least 1 argument (0 given)")
+ self = args[0]
+ other = args[1] if len(args) >= 2 else ()
+
+ if isinstance(other, Mapping):
+ for key in other:
+ self[key] = other[key]
+ elif hasattr(other, "keys"):
+ for key in other.keys():
+ self[key] = other[key]
+ else:
+ for key, value in other:
+ self[key] = value
+ for key, value in kwds.items():
+ self[key] = value
+
+ def setdefault(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError:
+ self[key] = default
+ return default
+
+ _Mapping.register(Mapping)
+
+else:
+ # In Python 3 we can just use MutableMapping directly, because it defines
+ # __slots__.
+ from collections import MutableMapping
+
class BaseContainer(object):
@@ -119,15 +258,23 @@ class RepeatedScalarFieldContainer(BaseContainer):
self._message_listener.Modified()
def extend(self, elem_seq):
- """Extends by appending the given sequence. Similar to list.extend()."""
- if not elem_seq:
- return
+ """Extends by appending the given iterable. Similar to list.extend()."""
- new_values = []
- for elem in elem_seq:
- new_values.append(self._type_checker.CheckValue(elem))
- self._values.extend(new_values)
- self._message_listener.Modified()
+ if elem_seq is None:
+ return
+ try:
+ elem_seq_iter = iter(elem_seq)
+ except TypeError:
+ if not elem_seq:
+ # silently ignore falsy inputs :-/.
+ # TODO(ptucker): Deprecate this behavior. b/18413862
+ return
+ raise
+
+ new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter]
+ if new_values:
+ self._values.extend(new_values)
+ self._message_listener.Modified()
def MergeFrom(self, other):
"""Appends the contents of another repeated field of the same type to this
@@ -141,6 +288,12 @@ class RepeatedScalarFieldContainer(BaseContainer):
self._values.remove(elem)
self._message_listener.Modified()
+ def pop(self, key=-1):
+ """Removes and returns an item at a given index. Similar to list.pop()."""
+ value = self._values[key]
+ self.__delitem__(key)
+ return value
+
def __setitem__(self, key, value):
"""Sets the item on the specified position."""
if isinstance(key, slice): # PY3
@@ -245,6 +398,12 @@ class RepeatedCompositeFieldContainer(BaseContainer):
self._values.remove(elem)
self._message_listener.Modified()
+ def pop(self, key=-1):
+ """Removes and returns an item at a given index. Similar to list.pop()."""
+ value = self._values[key]
+ self.__delitem__(key)
+ return value
+
def __getslice__(self, start, stop):
"""Retrieves the subset of items from between the specified indices."""
return self._values[start:stop]
@@ -267,3 +426,160 @@ class RepeatedCompositeFieldContainer(BaseContainer):
raise TypeError('Can only compare repeated composite fields against '
'other repeated composite fields.')
return self._values == other._values
+
+
+class ScalarMap(MutableMapping):
+
+ """Simple, type-checked, dict-like container for holding repeated scalars."""
+
+ # Disallows assignment to other attributes.
+ __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener']
+
+ def __init__(self, message_listener, key_checker, value_checker):
+ """
+ Args:
+ message_listener: A MessageListener implementation.
+ The ScalarMap will call this object's Modified() method when it
+ is modified.
+ key_checker: A type_checkers.ValueChecker instance to run on keys
+ inserted into this container.
+ value_checker: A type_checkers.ValueChecker instance to run on values
+ inserted into this container.
+ """
+ self._message_listener = message_listener
+ self._key_checker = key_checker
+ self._value_checker = value_checker
+ self._values = {}
+
+ def __getitem__(self, key):
+ try:
+ return self._values[key]
+ except KeyError:
+ key = self._key_checker.CheckValue(key)
+ val = self._value_checker.DefaultValue()
+ self._values[key] = val
+ return val
+
+ def __contains__(self, item):
+ return item in self._values
+
+ # We need to override this explicitly, because our defaultdict-like behavior
+ # will make the default implementation (from our base class) always insert
+ # the key.
+ def get(self, key, default=None):
+ if key in self:
+ return self[key]
+ else:
+ return default
+
+ def __setitem__(self, key, value):
+ checked_key = self._key_checker.CheckValue(key)
+ checked_value = self._value_checker.CheckValue(value)
+ self._values[checked_key] = checked_value
+ self._message_listener.Modified()
+
+ def __delitem__(self, key):
+ del self._values[key]
+ self._message_listener.Modified()
+
+ def __len__(self):
+ return len(self._values)
+
+ def __iter__(self):
+ return iter(self._values)
+
+ def MergeFrom(self, other):
+ self._values.update(other._values)
+ self._message_listener.Modified()
+
+ # This is defined in the abstract base, but we can do it much more cheaply.
+ def clear(self):
+ self._values.clear()
+ self._message_listener.Modified()
+
+
+class MessageMap(MutableMapping):
+
+ """Simple, type-checked, dict-like container for with submessage values."""
+
+ # Disallows assignment to other attributes.
+ __slots__ = ['_key_checker', '_values', '_message_listener',
+ '_message_descriptor']
+
+ def __init__(self, message_listener, message_descriptor, key_checker):
+ """
+ Args:
+ message_listener: A MessageListener implementation.
+ The ScalarMap will call this object's Modified() method when it
+ is modified.
+ key_checker: A type_checkers.ValueChecker instance to run on keys
+ inserted into this container.
+ value_checker: A type_checkers.ValueChecker instance to run on values
+ inserted into this container.
+ """
+ self._message_listener = message_listener
+ self._message_descriptor = message_descriptor
+ self._key_checker = key_checker
+ self._values = {}
+
+ def __getitem__(self, key):
+ try:
+ return self._values[key]
+ except KeyError:
+ key = self._key_checker.CheckValue(key)
+ new_element = self._message_descriptor._concrete_class()
+ new_element._SetListener(self._message_listener)
+ self._values[key] = new_element
+ self._message_listener.Modified()
+
+ return new_element
+
+ def get_or_create(self, key):
+ """get_or_create() is an alias for getitem (ie. map[key]).
+
+ Args:
+ key: The key to get or create in the map.
+
+ This is useful in cases where you want to be explicit that the call is
+ mutating the map. This can avoid lint errors for statements like this
+ that otherwise would appear to be pointless statements:
+
+ msg.my_map[key]
+ """
+ return self[key]
+
+ # We need to override this explicitly, because our defaultdict-like behavior
+ # will make the default implementation (from our base class) always insert
+ # the key.
+ def get(self, key, default=None):
+ if key in self:
+ return self[key]
+ else:
+ return default
+
+ def __contains__(self, item):
+ return item in self._values
+
+ def __setitem__(self, key, value):
+ raise ValueError('May not set values directly, call my_map[key].foo = 5')
+
+ def __delitem__(self, key):
+ del self._values[key]
+ self._message_listener.Modified()
+
+ def __len__(self):
+ return len(self._values)
+
+ def __iter__(self):
+ return iter(self._values)
+
+ def MergeFrom(self, other):
+ for key in other:
+ self[key].MergeFrom(other[key])
+ # self._message_listener.Modified() not required here, because
+ # mutations to submessages already propagate.
+
+ # This is defined in the abstract base, but we can do it much more cheaply.
+ def clear(self):
+ self._values.clear()
+ self._message_listener.Modified()
diff --git a/python/google/protobuf/internal/cpp_message.py b/python/google/protobuf/internal/cpp_message.py
deleted file mode 100755
index 20457375..00000000
--- a/python/google/protobuf/internal/cpp_message.py
+++ /dev/null
@@ -1,667 +0,0 @@
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Contains helper functions used to create protocol message classes from
-Descriptor objects at runtime backed by the protocol buffer C++ API.
-"""
-
-__author__ = 'petar@google.com (Petar Petrov)'
-
-import collections
-import operator
-
-import six
-import six.moves.copyreg
-
-from google.protobuf.internal import _net_proto2___python
-from google.protobuf.internal import enum_type_wrapper
-from google.protobuf import message
-
-
-_LABEL_REPEATED = _net_proto2___python.LABEL_REPEATED
-_LABEL_OPTIONAL = _net_proto2___python.LABEL_OPTIONAL
-_CPPTYPE_MESSAGE = _net_proto2___python.CPPTYPE_MESSAGE
-_TYPE_MESSAGE = _net_proto2___python.TYPE_MESSAGE
-
-
-def GetDescriptorPool():
- """Creates a new DescriptorPool C++ object."""
- return _net_proto2___python.NewCDescriptorPool()
-
-
-_pool = GetDescriptorPool()
-
-
-def GetFieldDescriptor(full_field_name):
- """Searches for a field descriptor given a full field name."""
- return _pool.FindFieldByName(full_field_name)
-
-
-def BuildFile(content):
- """Registers a new proto file in the underlying C++ descriptor pool."""
- _net_proto2___python.BuildFile(content)
-
-
-def GetExtensionDescriptor(full_extension_name):
- """Searches for extension descriptor given a full field name."""
- return _pool.FindExtensionByName(full_extension_name)
-
-
-def NewCMessage(full_message_name):
- """Creates a new C++ protocol message by its name."""
- return _net_proto2___python.NewCMessage(full_message_name)
-
-
-def ScalarProperty(cdescriptor):
- """Returns a scalar property for the given descriptor."""
-
- def Getter(self):
- return self._cmsg.GetScalar(cdescriptor)
-
- def Setter(self, value):
- self._cmsg.SetScalar(cdescriptor, value)
-
- return property(Getter, Setter)
-
-
-def CompositeProperty(cdescriptor, message_type):
- """Returns a Python property the given composite field."""
-
- def Getter(self):
- sub_message = self._composite_fields.get(cdescriptor.name, None)
- if sub_message is None:
- cmessage = self._cmsg.NewSubMessage(cdescriptor)
- sub_message = message_type._concrete_class(__cmessage=cmessage)
- self._composite_fields[cdescriptor.name] = sub_message
- return sub_message
-
- return property(Getter)
-
-
-class RepeatedScalarContainer(object):
- """Container for repeated scalar fields."""
-
- __slots__ = ['_message', '_cfield_descriptor', '_cmsg']
-
- def __init__(self, msg, cfield_descriptor):
- self._message = msg
- self._cmsg = msg._cmsg
- self._cfield_descriptor = cfield_descriptor
-
- def append(self, value):
- self._cmsg.AddRepeatedScalar(
- self._cfield_descriptor, value)
-
- def extend(self, sequence):
- for element in sequence:
- self.append(element)
-
- def insert(self, key, value):
- values = self[slice(None, None, None)]
- values.insert(key, value)
- self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)
-
- def remove(self, value):
- values = self[slice(None, None, None)]
- values.remove(value)
- self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)
-
- def __setitem__(self, key, value):
- values = self[slice(None, None, None)]
- values[key] = value
- self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)
-
- def __getitem__(self, key):
- return self._cmsg.GetRepeatedScalar(self._cfield_descriptor, key)
-
- def __delitem__(self, key):
- self._cmsg.DeleteRepeatedField(self._cfield_descriptor, key)
-
- def __len__(self):
- return len(self[slice(None, None, None)])
-
- def __eq__(self, other):
- if self is other:
- return True
- if not isinstance(other, collections.Sequence):
- raise TypeError(
- 'Can only compare repeated scalar fields against sequences.')
- # We are presumably comparing against some other sequence type.
- return other == self[slice(None, None, None)]
-
- def __ne__(self, other):
- return not self == other
-
- def __hash__(self):
- raise TypeError('unhashable object')
-
- def sort(self, *args, **kwargs):
- # Maintain compatibility with the previous interface.
- if 'sort_function' in kwargs:
- kwargs['cmp'] = kwargs.pop('sort_function')
- self._cmsg.AssignRepeatedScalar(self._cfield_descriptor,
- sorted(self, *args, **kwargs))
-
-
-def RepeatedScalarProperty(cdescriptor):
- """Returns a Python property the given repeated scalar field."""
-
- def Getter(self):
- container = self._composite_fields.get(cdescriptor.name, None)
- if container is None:
- container = RepeatedScalarContainer(self, cdescriptor)
- self._composite_fields[cdescriptor.name] = container
- return container
-
- def Setter(self, new_value):
- raise AttributeError('Assignment not allowed to repeated field '
- '"%s" in protocol message object.' % cdescriptor.name)
-
- doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name
- return property(Getter, Setter, doc=doc)
-
-
-class RepeatedCompositeContainer(object):
- """Container for repeated composite fields."""
-
- __slots__ = ['_message', '_subclass', '_cfield_descriptor', '_cmsg']
-
- def __init__(self, msg, cfield_descriptor, subclass):
- self._message = msg
- self._cmsg = msg._cmsg
- self._subclass = subclass
- self._cfield_descriptor = cfield_descriptor
-
- def add(self, **kwargs):
- cmessage = self._cmsg.AddMessage(self._cfield_descriptor)
- return self._subclass(__cmessage=cmessage, __owner=self._message, **kwargs)
-
- def extend(self, elem_seq):
- """Extends by appending the given sequence of elements of the same type
- as this one, copying each individual message.
- """
- for message in elem_seq:
- self.add().MergeFrom(message)
-
- def remove(self, value):
- # TODO(protocol-devel): This is inefficient as it needs to generate a
- # message pointer for each message only to do index(). Move this to a C++
- # extension function.
- self.__delitem__(self[slice(None, None, None)].index(value))
-
- def MergeFrom(self, other):
- for message in other[:]:
- self.add().MergeFrom(message)
-
- def __getitem__(self, key):
- cmessages = self._cmsg.GetRepeatedMessage(
- self._cfield_descriptor, key)
- subclass = self._subclass
- if not isinstance(cmessages, list):
- return subclass(__cmessage=cmessages, __owner=self._message)
-
- return [subclass(__cmessage=m, __owner=self._message) for m in cmessages]
-
- def __delitem__(self, key):
- self._cmsg.DeleteRepeatedField(
- self._cfield_descriptor, key)
-
- def __len__(self):
- return self._cmsg.FieldLength(self._cfield_descriptor)
-
- def __eq__(self, other):
- """Compares the current instance with another one."""
- if self is other:
- return True
- if not isinstance(other, self.__class__):
- raise TypeError('Can only compare repeated composite fields against '
- 'other repeated composite fields.')
- messages = self[slice(None, None, None)]
- other_messages = other[slice(None, None, None)]
- return messages == other_messages
-
- def __hash__(self):
- raise TypeError('unhashable object')
-
- def sort(self, cmp=None, key=None, reverse=False, **kwargs):
- # Maintain compatibility with the old interface.
- if cmp is None and 'sort_function' in kwargs:
- cmp = kwargs.pop('sort_function')
-
- # The cmp function, if provided, is passed the results of the key function,
- # so we only need to wrap one of them.
- if key is None:
- index_key = self.__getitem__
- else:
- index_key = lambda i: key(self[i])
-
- # Sort the list of current indexes by the underlying object.
- indexes = list(range(len(self)))
- indexes.sort(cmp=cmp, key=index_key, reverse=reverse)
-
- # Apply the transposition.
- for dest, src in enumerate(indexes):
- if dest == src:
- continue
- self._cmsg.SwapRepeatedFieldElements(self._cfield_descriptor, dest, src)
- # Don't swap the same value twice.
- indexes[src] = src
-
-
-def RepeatedCompositeProperty(cdescriptor, message_type):
- """Returns a Python property for the given repeated composite field."""
-
- def Getter(self):
- container = self._composite_fields.get(cdescriptor.name, None)
- if container is None:
- container = RepeatedCompositeContainer(
- self, cdescriptor, message_type._concrete_class)
- self._composite_fields[cdescriptor.name] = container
- return container
-
- def Setter(self, new_value):
- raise AttributeError('Assignment not allowed to repeated field '
- '"%s" in protocol message object.' % cdescriptor.name)
-
- doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name
- return property(Getter, Setter, doc=doc)
-
-
-class ExtensionDict(object):
- """Extension dictionary added to each protocol message."""
-
- def __init__(self, msg):
- self._message = msg
- self._cmsg = msg._cmsg
- self._values = {}
-
- def __setitem__(self, extension, value):
- from google.protobuf import descriptor
- if not isinstance(extension, descriptor.FieldDescriptor):
- raise KeyError('Bad extension %r.' % (extension,))
- cdescriptor = extension._cdescriptor
- if (cdescriptor.label != _LABEL_OPTIONAL or
- cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
- raise TypeError('Extension %r is repeated and/or a composite type.' % (
- extension.full_name,))
- self._cmsg.SetScalar(cdescriptor, value)
- self._values[extension] = value
-
- def __getitem__(self, extension):
- from google.protobuf import descriptor
- if not isinstance(extension, descriptor.FieldDescriptor):
- raise KeyError('Bad extension %r.' % (extension,))
-
- cdescriptor = extension._cdescriptor
- if (cdescriptor.label != _LABEL_REPEATED and
- cdescriptor.cpp_type != _CPPTYPE_MESSAGE):
- return self._cmsg.GetScalar(cdescriptor)
-
- ext = self._values.get(extension, None)
- if ext is not None:
- return ext
-
- ext = self._CreateNewHandle(extension)
- self._values[extension] = ext
- return ext
-
- def ClearExtension(self, extension):
- from google.protobuf import descriptor
- if not isinstance(extension, descriptor.FieldDescriptor):
- raise KeyError('Bad extension %r.' % (extension,))
- self._cmsg.ClearFieldByDescriptor(extension._cdescriptor)
- if extension in self._values:
- del self._values[extension]
-
- def HasExtension(self, extension):
- from google.protobuf import descriptor
- if not isinstance(extension, descriptor.FieldDescriptor):
- raise KeyError('Bad extension %r.' % (extension,))
- return self._cmsg.HasFieldByDescriptor(extension._cdescriptor)
-
- def _FindExtensionByName(self, name):
- """Tries to find a known extension with the specified name.
-
- Args:
- name: Extension full name.
-
- Returns:
- Extension field descriptor.
- """
- return self._message._extensions_by_name.get(name, None)
-
- def _CreateNewHandle(self, extension):
- cdescriptor = extension._cdescriptor
- if (cdescriptor.label != _LABEL_REPEATED and
- cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
- cmessage = self._cmsg.NewSubMessage(cdescriptor)
- return extension.message_type._concrete_class(__cmessage=cmessage)
-
- if cdescriptor.label == _LABEL_REPEATED:
- if cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
- return RepeatedCompositeContainer(
- self._message, cdescriptor, extension.message_type._concrete_class)
- else:
- return RepeatedScalarContainer(self._message, cdescriptor)
- # This shouldn't happen!
- assert False
- return None
-
-
-def NewMessage(bases, message_descriptor, dictionary):
- """Creates a new protocol message *class*."""
- _AddClassAttributesForNestedExtensions(message_descriptor, dictionary)
- _AddEnumValues(message_descriptor, dictionary)
- _AddDescriptors(message_descriptor, dictionary)
- return bases
-
-
-def InitMessage(message_descriptor, cls):
- """Constructs a new message instance (called before instance's __init__)."""
- cls._extensions_by_name = {}
- _AddInitMethod(message_descriptor, cls)
- _AddMessageMethods(message_descriptor, cls)
- _AddPropertiesForExtensions(message_descriptor, cls)
- six.moves.copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
-
-
-def _AddDescriptors(message_descriptor, dictionary):
- """Sets up a new protocol message class dictionary.
-
- Args:
- message_descriptor: A Descriptor instance describing this message type.
- dictionary: Class dictionary to which we'll add a '__slots__' entry.
- """
- dictionary['__descriptors'] = {}
- for field in message_descriptor.fields:
- dictionary['__descriptors'][field.name] = GetFieldDescriptor(
- field.full_name)
-
- dictionary['__slots__'] = list(dictionary['__descriptors'].keys()) + [
- '_cmsg', '_owner', '_composite_fields', 'Extensions', '_HACK_REFCOUNTS']
-
-
-def _AddEnumValues(message_descriptor, dictionary):
- """Sets class-level attributes for all enum fields defined in this message.
-
- Args:
- message_descriptor: Descriptor object for this message type.
- dictionary: Class dictionary that should be populated.
- """
- for enum_type in message_descriptor.enum_types:
- dictionary[enum_type.name] = enum_type_wrapper.EnumTypeWrapper(enum_type)
- for enum_value in enum_type.values:
- dictionary[enum_value.name] = enum_value.number
-
-
-def _AddClassAttributesForNestedExtensions(message_descriptor, dictionary):
- """Adds class attributes for the nested extensions."""
- extension_dict = message_descriptor.extensions_by_name
- for extension_name, extension_field in extension_dict.items():
- assert extension_name not in dictionary
- dictionary[extension_name] = extension_field
-
-
-def _AddInitMethod(message_descriptor, cls):
- """Adds an __init__ method to cls."""
-
- # Create and attach message field properties to the message class.
- # This can be done just once per message class, since property setters and
- # getters are passed the message instance.
- # This makes message instantiation extremely fast, and at the same time it
- # doesn't require the creation of property objects for each message instance,
- # which saves a lot of memory.
- for field in message_descriptor.fields:
- field_cdescriptor = cls.__descriptors[field.name]
- if field.label == _LABEL_REPEATED:
- if field.cpp_type == _CPPTYPE_MESSAGE:
- value = RepeatedCompositeProperty(field_cdescriptor, field.message_type)
- else:
- value = RepeatedScalarProperty(field_cdescriptor)
- elif field.cpp_type == _CPPTYPE_MESSAGE:
- value = CompositeProperty(field_cdescriptor, field.message_type)
- else:
- value = ScalarProperty(field_cdescriptor)
- setattr(cls, field.name, value)
-
- # Attach a constant with the field number.
- constant_name = field.name.upper() + '_FIELD_NUMBER'
- setattr(cls, constant_name, field.number)
-
- def Init(self, **kwargs):
- """Message constructor."""
- cmessage = kwargs.pop('__cmessage', None)
- if cmessage:
- self._cmsg = cmessage
- else:
- self._cmsg = NewCMessage(message_descriptor.full_name)
-
- # Keep a reference to the owner, as the owner keeps a reference to the
- # underlying protocol buffer message.
- owner = kwargs.pop('__owner', None)
- if owner:
- self._owner = owner
-
- if message_descriptor.is_extendable:
- self.Extensions = ExtensionDict(self)
- else:
- # Reference counting in the C++ code is broken and depends on
- # the Extensions reference to keep this object alive during unit
- # tests (see b/4856052). Remove this once b/4945904 is fixed.
- self._HACK_REFCOUNTS = self
- self._composite_fields = {}
-
- for field_name, field_value in kwargs.items():
- field_cdescriptor = self.__descriptors.get(field_name, None)
- if not field_cdescriptor:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
- if field_cdescriptor.label == _LABEL_REPEATED:
- if field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
- field_name = getattr(self, field_name)
- for val in field_value:
- field_name.add().MergeFrom(val)
- else:
- getattr(self, field_name).extend(field_value)
- elif field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
- getattr(self, field_name).MergeFrom(field_value)
- else:
- setattr(self, field_name, field_value)
-
- Init.__module__ = None
- Init.__doc__ = None
- cls.__init__ = Init
-
-
-def _IsMessageSetExtension(field):
- """Checks if a field is a message set extension."""
- return (field.is_extension and
- field.containing_type.has_options and
- field.containing_type.GetOptions().message_set_wire_format and
- field.type == _TYPE_MESSAGE and
- field.message_type == field.extension_scope and
- field.label == _LABEL_OPTIONAL)
-
-
-def _AddMessageMethods(message_descriptor, cls):
- """Adds the methods to a protocol message class."""
- if message_descriptor.is_extendable:
-
- def ClearExtension(self, extension):
- self.Extensions.ClearExtension(extension)
-
- def HasExtension(self, extension):
- return self.Extensions.HasExtension(extension)
-
- def HasField(self, field_name):
- return self._cmsg.HasField(field_name)
-
- def ClearField(self, field_name):
- child_cmessage = None
- if field_name in self._composite_fields:
- child_field = self._composite_fields[field_name]
- del self._composite_fields[field_name]
-
- child_cdescriptor = self.__descriptors[field_name]
- # TODO(anuraag): Support clearing repeated message fields as well.
- if (child_cdescriptor.label != _LABEL_REPEATED and
- child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
- child_field._owner = None
- child_cmessage = child_field._cmsg
-
- if child_cmessage is not None:
- self._cmsg.ClearField(field_name, child_cmessage)
- else:
- self._cmsg.ClearField(field_name)
-
- def Clear(self):
- cmessages_to_release = []
- for field_name, child_field in self._composite_fields.items():
- child_cdescriptor = self.__descriptors[field_name]
- # TODO(anuraag): Support clearing repeated message fields as well.
- if (child_cdescriptor.label != _LABEL_REPEATED and
- child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
- child_field._owner = None
- cmessages_to_release.append((child_cdescriptor, child_field._cmsg))
- self._composite_fields.clear()
- self._cmsg.Clear(cmessages_to_release)
-
- def IsInitialized(self, errors=None):
- if self._cmsg.IsInitialized():
- return True
- if errors is not None:
- errors.extend(self.FindInitializationErrors());
- return False
-
- def SerializeToString(self):
- if not self.IsInitialized():
- raise message.EncodeError(
- 'Message %s is missing required fields: %s' % (
- self._cmsg.full_name, ','.join(self.FindInitializationErrors())))
- return self._cmsg.SerializeToString()
-
- def SerializePartialToString(self):
- return self._cmsg.SerializePartialToString()
-
- def ParseFromString(self, serialized):
- self.Clear()
- self.MergeFromString(serialized)
-
- def MergeFromString(self, serialized):
- byte_size = self._cmsg.MergeFromString(serialized)
- if byte_size < 0:
- raise message.DecodeError('Unable to merge from string.')
- return byte_size
-
- def MergeFrom(self, msg):
- if not isinstance(msg, cls):
- raise TypeError(
- "Parameter to MergeFrom() must be instance of same class: "
- "expected %s got %s." % (cls.__name__, type(msg).__name__))
- self._cmsg.MergeFrom(msg._cmsg)
-
- def CopyFrom(self, msg):
- self._cmsg.CopyFrom(msg._cmsg)
-
- def ByteSize(self):
- return self._cmsg.ByteSize()
-
- def SetInParent(self):
- return self._cmsg.SetInParent()
-
- def ListFields(self):
- all_fields = []
- field_list = self._cmsg.ListFields()
- fields_by_name = cls.DESCRIPTOR.fields_by_name
- for is_extension, field_name in field_list:
- if is_extension:
- extension = cls._extensions_by_name[field_name]
- all_fields.append((extension, self.Extensions[extension]))
- else:
- field_descriptor = fields_by_name[field_name]
- all_fields.append(
- (field_descriptor, getattr(self, field_name)))
- all_fields.sort(key=lambda item: item[0].number)
- return all_fields
-
- def FindInitializationErrors(self):
- return self._cmsg.FindInitializationErrors()
-
- def __str__(self):
- return str(self._cmsg)
-
- def __eq__(self, other):
- if self is other:
- return True
- if not isinstance(other, self.__class__):
- return False
- return self.ListFields() == other.ListFields()
-
- def __ne__(self, other):
- return not self == other
-
- def __hash__(self):
- raise TypeError('unhashable object')
-
- def __unicode__(self):
- # Lazy import to prevent circular import when text_format imports this file.
- from google.protobuf import text_format
- return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
-
- # Attach the local methods to the message class.
- for key, value in locals().copy().items():
- if key not in ('key', 'value', '__builtins__', '__name__', '__doc__'):
- setattr(cls, key, value)
-
- # Static methods:
-
- def RegisterExtension(extension_handle):
- extension_handle.containing_type = cls.DESCRIPTOR
- cls._extensions_by_name[extension_handle.full_name] = extension_handle
-
- if _IsMessageSetExtension(extension_handle):
- # MessageSet extension. Also register under type name.
- cls._extensions_by_name[
- extension_handle.message_type.full_name] = extension_handle
- cls.RegisterExtension = staticmethod(RegisterExtension)
-
- def FromString(string):
- msg = cls()
- msg.MergeFromString(string)
- return msg
- cls.FromString = staticmethod(FromString)
-
-
-
-def _AddPropertiesForExtensions(message_descriptor, cls):
- """Adds properties for all fields in this protocol message type."""
- extension_dict = message_descriptor.extensions_by_name
- for extension_name, extension_field in extension_dict.items():
- constant_name = extension_name.upper() + '_FIELD_NUMBER'
- setattr(cls, constant_name, extension_field.number)
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 6b72adef..130386f2 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -602,9 +602,6 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
if value is None:
value = field_dict.setdefault(key, new_default(message))
while 1:
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
# Read length.
(size, pos) = local_DecodeVarint(buffer, pos)
new_pos = pos + size
@@ -717,6 +714,50 @@ def MessageSetItemDecoder(extensions_by_number):
return DecodeItem
# --------------------------------------------------------------------
+
+def MapDecoder(field_descriptor, new_default, is_message_map):
+ """Returns a decoder for a map field."""
+
+ key = field_descriptor
+ tag_bytes = encoder.TagBytes(field_descriptor.number,
+ wire_format.WIRETYPE_LENGTH_DELIMITED)
+ tag_len = len(tag_bytes)
+ local_DecodeVarint = _DecodeVarint
+ # Can't read _concrete_class yet; might not be initialized.
+ message_type = field_descriptor.message_type
+
+ def DecodeMap(buffer, pos, end, message, field_dict):
+ submsg = message_type._concrete_class()
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ while 1:
+ # Read length.
+ (size, pos) = local_DecodeVarint(buffer, pos)
+ new_pos = pos + size
+ if new_pos > end:
+ raise _DecodeError('Truncated message.')
+ # Read sub-message.
+ submsg.Clear()
+ if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
+ # The only reason _InternalParse would return early is if it
+ # encountered an end-group tag.
+ raise _DecodeError('Unexpected end-group tag.')
+
+ if is_message_map:
+ value[submsg.key].MergeFrom(submsg.value)
+ else:
+ value[submsg.key] = submsg.value
+
+ # Predict that the next tag is another copy of the same repeated field.
+ pos = new_pos + tag_len
+ if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+ # Prediction failed. Return.
+ return new_pos
+
+ return DecodeMap
+
+# --------------------------------------------------------------------
# Optimization is not as heavy here because calls to SkipField() are rare,
# except for handling end-group tags.
diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py
index ad13f889..3241cb72 100644
--- a/python/google/protobuf/internal/descriptor_database_test.py
+++ b/python/google/protobuf/internal/descriptor_database_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -35,7 +35,6 @@
__author__ = 'matthewtoia@google.com (Matt Toia)'
import unittest
-
from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test2_pb2
from google.protobuf import descriptor_database
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index fa1a511a..64b5d172 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -37,6 +37,7 @@ __author__ = 'matthewtoia@google.com (Matt Toia)'
import os
import unittest
+import unittest
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf.internal import api_implementation
@@ -226,6 +227,13 @@ class DescriptorPoolTest(unittest.TestCase):
db.Add(self.factory_test2_fd)
self.testFindMessageTypeByName()
+ def testAddSerializedFile(self):
+ db = descriptor_database.DescriptorDatabase()
+ self.pool = descriptor_pool.DescriptorPool(db)
+ self.pool.AddSerializedFile(self.factory_test1_fd.SerializeToString())
+ self.pool.AddSerializedFile(self.factory_test2_fd.SerializeToString())
+ self.testFindMessageTypeByName()
+
def testComplexNesting(self):
test1_desc = descriptor_pb2.FileDescriptorProto.FromString(
descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
diff --git a/python/google/protobuf/internal/descriptor_python_test.py b/python/google/protobuf/internal/descriptor_python_test.py
deleted file mode 100644
index 573c1b9d..00000000
--- a/python/google/protobuf/internal/descriptor_python_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Unittest for descriptor.py for the pure Python implementation."""
-
-import os
-import unittest
-os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
-
-# We must set the implementation version above before the google3 imports.
-# pylint: disable=g-import-not-at-top
-from google.protobuf.internal import api_implementation
-# Run all tests from the original module by putting them in our namespace.
-# pylint: disable=wildcard-import
-from google.protobuf.internal.descriptor_test import *
-
-
-class ConfirmPurePythonTest(unittest.TestCase):
-
- def testImplementationSetting(self):
- self.assertEqual('python', api_implementation.Type())
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index e1506fa4..a40ec0e4 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -34,13 +34,16 @@
__author__ = 'robinson@google.com (Will Robinson)'
-import unittest
+import sys
+import unittest
from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
+from google.protobuf.internal import api_implementation
from google.protobuf import descriptor
+from google.protobuf import symbol_database
from google.protobuf import text_format
@@ -52,41 +55,28 @@ name: 'TestEmptyMessage'
class DescriptorTest(unittest.TestCase):
def setUp(self):
- self.my_file = descriptor.FileDescriptor(
+ file_proto = descriptor_pb2.FileDescriptorProto(
name='some/filename/some.proto',
- package='protobuf_unittest'
- )
- self.my_enum = descriptor.EnumDescriptor(
- name='ForeignEnum',
- full_name='protobuf_unittest.ForeignEnum',
- filename=None,
- file=self.my_file,
- values=[
- descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4),
- descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5),
- descriptor.EnumValueDescriptor(name='FOREIGN_BAZ', index=2, number=6),
- ])
- self.my_message = descriptor.Descriptor(
- name='NestedMessage',
- full_name='protobuf_unittest.TestAllTypes.NestedMessage',
- filename=None,
- file=self.my_file,
- containing_type=None,
- fields=[
- descriptor.FieldDescriptor(
- name='bb',
- full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb',
- index=0, number=1,
- type=5, cpp_type=1, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None),
- ],
- nested_types=[],
- enum_types=[
- self.my_enum,
- ],
- extensions=[])
+ package='protobuf_unittest')
+ message_proto = file_proto.message_type.add(
+ name='NestedMessage')
+ message_proto.field.add(
+ name='bb',
+ number=1,
+ type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32,
+ label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL)
+ enum_proto = message_proto.enum_type.add(
+ name='ForeignEnum')
+ enum_proto.value.add(name='FOREIGN_FOO', number=4)
+ enum_proto.value.add(name='FOREIGN_BAR', number=5)
+ enum_proto.value.add(name='FOREIGN_BAZ', number=6)
+
+ descriptor_pool = symbol_database.Default().pool
+ descriptor_pool.Add(file_proto)
+ self.my_file = descriptor_pool.FindFileByName(file_proto.name)
+ self.my_message = self.my_file.message_types_by_name[message_proto.name]
+ self.my_enum = self.my_message.enum_types_by_name[enum_proto.name]
+
self.my_method = descriptor.MethodDescriptor(
name='Bar',
full_name='protobuf_unittest.TestService.Bar',
@@ -174,6 +164,11 @@ class DescriptorTest(unittest.TestCase):
self.assertEqual(unittest_custom_options_pb2.METHODOPT1_VAL2,
method_options.Extensions[method_opt1])
+ message_descriptor = (
+ unittest_custom_options_pb2.DummyMessageContainingEnum.DESCRIPTOR)
+ self.assertTrue(file_descriptor.has_options)
+ self.assertFalse(message_descriptor.has_options)
+
def testDifferentCustomOptionTypes(self):
kint32min = -2**31
kint64min = -2**63
@@ -395,6 +390,108 @@ class DescriptorTest(unittest.TestCase):
self.assertEqual(self.my_file.name, 'some/filename/some.proto')
self.assertEqual(self.my_file.package, 'protobuf_unittest')
+ @unittest.skipIf(
+ api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
+ 'Immutability of descriptors is only enforced in v2 implementation')
+ def testImmutableCppDescriptor(self):
+ message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ with self.assertRaises(AttributeError):
+ message_descriptor.fields_by_name = None
+ with self.assertRaises(TypeError):
+ message_descriptor.fields_by_name['Another'] = None
+ with self.assertRaises(TypeError):
+ message_descriptor.fields.append(None)
+
+
+class GeneratedDescriptorTest(unittest.TestCase):
+ """Tests for the properties of descriptors in generated code."""
+
+ def CheckMessageDescriptor(self, message_descriptor):
+ # Basic properties
+ self.assertEqual(message_descriptor.name, 'TestAllTypes')
+ self.assertEqual(message_descriptor.full_name,
+ 'protobuf_unittest.TestAllTypes')
+ # Test equality and hashability
+ self.assertEqual(message_descriptor, message_descriptor)
+ self.assertEqual(message_descriptor.fields[0].containing_type,
+ message_descriptor)
+ self.assertIn(message_descriptor, [message_descriptor])
+ self.assertIn(message_descriptor, {message_descriptor: None})
+ # Test field containers
+ self.CheckDescriptorSequence(message_descriptor.fields)
+ self.CheckDescriptorMapping(message_descriptor.fields_by_name)
+ self.CheckDescriptorMapping(message_descriptor.fields_by_number)
+
+ def CheckFieldDescriptor(self, field_descriptor):
+ # Basic properties
+ self.assertEqual(field_descriptor.name, 'optional_int32')
+ self.assertEqual(field_descriptor.full_name,
+ 'protobuf_unittest.TestAllTypes.optional_int32')
+ self.assertEqual(field_descriptor.containing_type.name, 'TestAllTypes')
+ # Test equality and hashability
+ self.assertEqual(field_descriptor, field_descriptor)
+ self.assertEqual(
+ field_descriptor.containing_type.fields_by_name['optional_int32'],
+ field_descriptor)
+ self.assertIn(field_descriptor, [field_descriptor])
+ self.assertIn(field_descriptor, {field_descriptor: None})
+
+ def CheckDescriptorSequence(self, sequence):
+ # Verifies that a property like 'messageDescriptor.fields' has all the
+ # properties of an immutable abc.Sequence.
+ self.assertGreater(len(sequence), 0) # Sized
+ self.assertEqual(len(sequence), len(list(sequence))) # Iterable
+ item = sequence[0]
+ self.assertEqual(item, sequence[0])
+ self.assertIn(item, sequence) # Container
+ self.assertEqual(sequence.index(item), 0)
+ self.assertEqual(sequence.count(item), 1)
+ reversed_iterator = reversed(sequence)
+ self.assertEqual(list(reversed_iterator), list(sequence)[::-1])
+ self.assertRaises(StopIteration, next, reversed_iterator)
+
+ def CheckDescriptorMapping(self, mapping):
+ # Verifies that a property like 'messageDescriptor.fields' has all the
+ # properties of an immutable abc.Mapping.
+ self.assertGreater(len(mapping), 0) # Sized
+ self.assertEqual(len(mapping), len(list(mapping))) # Iterable
+ if sys.version_info.major >= 3:
+ key, item = next(iter(mapping.items()))
+ else:
+ key, item = mapping.items()[0]
+ self.assertIn(key, mapping) # Container
+ self.assertEqual(mapping.get(key), item)
+ # keys(), iterkeys() &co
+ item = (next(iter(mapping.keys())), next(iter(mapping.values())))
+ self.assertEqual(item, next(iter(mapping.items())))
+ if sys.version_info.major < 3:
+ def CheckItems(seq, iterator):
+ self.assertEqual(next(iterator), seq[0])
+ self.assertEqual(list(iterator), seq[1:])
+ CheckItems(mapping.keys(), mapping.iterkeys())
+ CheckItems(mapping.values(), mapping.itervalues())
+ CheckItems(mapping.items(), mapping.iteritems())
+
+ def testDescriptor(self):
+ message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ self.CheckMessageDescriptor(message_descriptor)
+ field_descriptor = message_descriptor.fields_by_name['optional_int32']
+ self.CheckFieldDescriptor(field_descriptor)
+
+ def testCppDescriptorContainer(self):
+ # Check that the collection is still valid even if the parent disappeared.
+ enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum']
+ values = enum.values
+ del enum
+ self.assertEqual('FOO', values[0].name)
+
+ def testCppDescriptorContainer_Iterator(self):
+ # Same test with the iterator
+ enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum']
+ values_iter = iter(enum.values)
+ del enum
+ self.assertEqual('FOO', next(values_iter).name)
+
class DescriptorCopyToProtoTest(unittest.TestCase):
"""Tests for CopyTo functions of Descriptor."""
@@ -589,10 +686,12 @@ class DescriptorCopyToProtoTest(unittest.TestCase):
output_type: '.protobuf_unittest.BarResponse'
>
"""
- self._InternalTestCopyToProto(
- unittest_pb2.TestService.DESCRIPTOR,
- descriptor_pb2.ServiceDescriptorProto,
- TEST_SERVICE_ASCII)
+ # TODO(rocking): enable this test after the proto descriptor change is
+ # checked in.
+ #self._InternalTestCopyToProto(
+ # unittest_pb2.TestService.DESCRIPTOR,
+ # descriptor_pb2.ServiceDescriptorProto,
+ # TEST_SERVICE_ASCII)
class MakeDescriptorTest(unittest.TestCase):
diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py
index fa22a9dd..d72cd29d 100755
--- a/python/google/protobuf/internal/encoder.py
+++ b/python/google/protobuf/internal/encoder.py
@@ -313,7 +313,7 @@ def MessageSizer(field_number, is_repeated, is_packed):
# --------------------------------------------------------------------
-# MessageSet is special.
+# MessageSet is special: it needs custom logic to compute its size properly.
def MessageSetItemSizer(field_number):
@@ -338,6 +338,32 @@ def MessageSetItemSizer(field_number):
return FieldSize
+# --------------------------------------------------------------------
+# Map is special: it needs custom logic to compute its size properly.
+
+
+def MapSizer(field_descriptor):
+ """Returns a sizer for a map field."""
+
+ # Can't look at field_descriptor.message_type._concrete_class because it may
+ # not have been initialized yet.
+ message_type = field_descriptor.message_type
+ message_sizer = MessageSizer(field_descriptor.number, False, False)
+
+ def FieldSize(map_value):
+ total = 0
+ for key in map_value:
+ value = map_value[key]
+ # It's wasteful to create the messages and throw them away one second
+ # later since we'll do the same for the actual encode. But there's not an
+ # obvious way to avoid this within the current design without tons of code
+ # duplication.
+ entry_msg = message_type._concrete_class(key=key, value=value)
+ total += message_sizer(entry_msg)
+ return total
+
+ return FieldSize
+
# ====================================================================
# Encoders!
@@ -770,3 +796,30 @@ def MessageSetItemEncoder(field_number):
return write(end_bytes)
return EncodeField
+
+
+# --------------------------------------------------------------------
+# As before, Map is special.
+
+
+def MapEncoder(field_descriptor):
+ """Encoder for extensions of MessageSet.
+
+ Maps always have a wire format like this:
+ message MapEntry {
+ key_type key = 1;
+ value_type value = 2;
+ }
+ repeated MapEntry map = N;
+ """
+ # Can't look at field_descriptor.message_type._concrete_class because it may
+ # not have been initialized yet.
+ message_type = field_descriptor.message_type
+ encode_message = MessageEncoder(field_descriptor.number, False, False)
+
+ def EncodeField(write, value):
+ for key in value:
+ entry_msg = message_type._concrete_class(key=key, value=value[key])
+ encode_message(write, entry_msg)
+
+ return EncodeField
diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py
index afcf6227..cc67f19f 100755
--- a/python/google/protobuf/internal/generator_test.py
+++ b/python/google/protobuf/internal/generator_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -42,7 +42,6 @@ further ensures that we can use Python protocol message objects as we expect.
__author__ = 'robinson@google.com (Will Robinson)'
import unittest
-
from google.protobuf.internal import test_bad_identifiers_pb2
from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_import_pb2
@@ -154,7 +153,7 @@ class GeneratorTest(unittest.TestCase):
# extension and for its value to be set to -789.
def testNestedTypes(self):
- self.assertEqual(
+ self.assertEquals(
set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types),
set([
unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR,
@@ -292,7 +291,7 @@ class GeneratorTest(unittest.TestCase):
self.assertIs(desc.oneofs[0], desc.oneofs_by_name['oneof_field'])
nested_names = set(['oneof_uint32', 'oneof_nested_message',
'oneof_string', 'oneof_bytes'])
- self.assertSameElements(
+ self.assertItemsEqual(
nested_names,
[field.name for field in desc.oneofs[0].fields])
for field_name, field_desc in desc.fields_by_name.items():
diff --git a/python/google/protobuf/internal/import_test_package/BUILD b/python/google/protobuf/internal/import_test_package/BUILD
deleted file mode 100644
index 90e59505..00000000
--- a/python/google/protobuf/internal/import_test_package/BUILD
+++ /dev/null
@@ -1,27 +0,0 @@
-# Description:
-# An example package that contains nested protos that are imported from
-# __init__.py. See testPackageInitializationImport in reflection_test.py for
-# details.
-
-package(
- default_visibility = ["//net/proto2/python/internal:__pkg__"],
-)
-
-proto_library(
- name = "inner_proto",
- srcs = ["inner.proto"],
- py_api_version = 2,
-)
-
-proto_library(
- name = "outer_proto",
- srcs = ["outer.proto"],
- py_api_version = 2,
- deps = [":inner_proto"],
-)
-
-py_library(
- name = "import_test_package",
- srcs = ["__init__.py"],
- deps = [":outer_proto"],
-)
diff --git a/python/google/protobuf/internal/message_factory_python_test.py b/python/google/protobuf/internal/message_factory_python_test.py
deleted file mode 100644
index eeb164b1..00000000
--- a/python/google/protobuf/internal/message_factory_python_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Tests for ..public.message_factory for the pure Python implementation."""
-
-import os
-import unittest
-os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
-
-# We must set the implementation version above before the google3 imports.
-# pylint: disable=g-import-not-at-top
-from google.protobuf.internal import api_implementation
-# Run all tests from the original module by putting them in our namespace.
-# pylint: disable=wildcard-import
-from google.protobuf.internal.message_factory_test import *
-
-
-class ConfirmPurePythonTest(unittest.TestCase):
-
- def testImplementationSetting(self):
- self.assertEqual('python', api_implementation.Type())
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py
index b33539a0..27a3f08b 100644
--- a/python/google/protobuf/internal/message_factory_test.py
+++ b/python/google/protobuf/internal/message_factory_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -35,7 +35,6 @@
__author__ = 'matthewtoia@google.com (Matt Toia)'
import unittest
-
from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test1_pb2
from google.protobuf.internal import factory_test2_pb2
diff --git a/python/google/protobuf/internal/message_python_test.py b/python/google/protobuf/internal/message_python_test.py
deleted file mode 100644
index ef57967b..00000000
--- a/python/google/protobuf/internal/message_python_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Tests for ..public.message for the pure Python implementation."""
-
-import os
-import unittest
-os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
-
-# We must set the implementation version above before the google3 imports.
-# pylint: disable=g-import-not-at-top
-from google.protobuf.internal import api_implementation
-# Run all tests from the original module by putting them in our namespace.
-# pylint: disable=wildcard-import
-from google.protobuf.internal.message_test import *
-
-
-class ConfirmPurePythonTest(unittest.TestCase):
-
- def testImplementationSetting(self):
- self.assertEqual('python', api_implementation.Type())
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index e69c49b6..4dc92752 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -48,9 +48,12 @@ import math
import operator
import pickle
import sys
-import unittest
+import unittest
+from google.protobuf.internal import _parameterized
+from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_pb2
+from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import test_util
from google.protobuf import message
@@ -69,88 +72,72 @@ def IsNegInf(val):
return isinf(val) and (val < 0)
+@_parameterized.Parameters(
+ (unittest_pb2),
+ (unittest_proto3_arena_pb2))
class MessageTest(unittest.TestCase):
- def testBadUtf8String(self):
+ def testBadUtf8String(self, message_module):
if api_implementation.Type() != 'python':
self.skipTest("Skipping testBadUtf8String, currently only the python "
"api implementation raises UnicodeDecodeError when a "
"string field contains bad utf-8.")
bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
with self.assertRaises(UnicodeDecodeError) as context:
- unittest_pb2.TestAllTypes.FromString(bad_utf8_data)
- self.assertIn('field: protobuf_unittest.TestAllTypes.optional_string',
- str(context.exception))
-
- def testGoldenMessage(self):
- golden_data = test_util.GoldenFileData(
- 'golden_message_oneof_implemented')
- golden_message = unittest_pb2.TestAllTypes()
- golden_message.ParseFromString(golden_data)
- test_util.ExpectAllFieldsSet(self, golden_message)
- self.assertEqual(golden_data, golden_message.SerializeToString())
- golden_copy = copy.deepcopy(golden_message)
- self.assertEqual(golden_data, golden_copy.SerializeToString())
+ message_module.TestAllTypes.FromString(bad_utf8_data)
+ self.assertIn('TestAllTypes.optional_string', str(context.exception))
+
+ def testGoldenMessage(self, message_module):
+ # Proto3 doesn't have the "default_foo" members or foreign enums,
+ # and doesn't preserve unknown fields, so for proto3 we use a golden
+ # message that doesn't have these fields set.
+ if message_module is unittest_pb2:
+ golden_data = test_util.GoldenFileData(
+ 'golden_message_oneof_implemented')
+ else:
+ golden_data = test_util.GoldenFileData('golden_message_proto3')
- def testGoldenExtensions(self):
- golden_data = test_util.GoldenFileData('golden_message')
- golden_message = unittest_pb2.TestAllExtensions()
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
- all_set = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(all_set)
- self.assertEqual(all_set, golden_message)
+ if message_module is unittest_pb2:
+ test_util.ExpectAllFieldsSet(self, golden_message)
self.assertEqual(golden_data, golden_message.SerializeToString())
golden_copy = copy.deepcopy(golden_message)
self.assertEqual(golden_data, golden_copy.SerializeToString())
- def testGoldenPackedMessage(self):
+ def testGoldenPackedMessage(self, message_module):
golden_data = test_util.GoldenFileData('golden_packed_fields_message')
- golden_message = unittest_pb2.TestPackedTypes()
+ golden_message = message_module.TestPackedTypes()
golden_message.ParseFromString(golden_data)
- all_set = unittest_pb2.TestPackedTypes()
+ all_set = message_module.TestPackedTypes()
test_util.SetAllPackedFields(all_set)
self.assertEqual(all_set, golden_message)
self.assertEqual(golden_data, all_set.SerializeToString())
golden_copy = copy.deepcopy(golden_message)
self.assertEqual(golden_data, golden_copy.SerializeToString())
- def testGoldenPackedExtensions(self):
- golden_data = test_util.GoldenFileData('golden_packed_fields_message')
- golden_message = unittest_pb2.TestPackedExtensions()
- golden_message.ParseFromString(golden_data)
- all_set = unittest_pb2.TestPackedExtensions()
- test_util.SetAllPackedExtensions(all_set)
- self.assertEqual(all_set, golden_message)
- self.assertEqual(golden_data, all_set.SerializeToString())
- golden_copy = copy.deepcopy(golden_message)
- self.assertEqual(golden_data, golden_copy.SerializeToString())
-
- def testPickleSupport(self):
+ def testPickleSupport(self, message_module):
golden_data = test_util.GoldenFileData('golden_message')
- golden_message = unittest_pb2.TestAllTypes()
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
pickled_message = pickle.dumps(golden_message)
unpickled_message = pickle.loads(pickled_message)
self.assertEqual(unpickled_message, golden_message)
+ def testPositiveInfinity(self, message_module):
+ if message_module is unittest_pb2:
+ golden_data = (b'\x5D\x00\x00\x80\x7F'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
+ b'\xCD\x02\x00\x00\x80\x7F'
+ b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
+ else:
+ golden_data = (b'\x5D\x00\x00\x80\x7F'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
+ b'\xCA\x02\x04\x00\x00\x80\x7F'
+ b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
- def testPickleIncompleteProto(self):
- golden_message = unittest_pb2.TestRequired(a=1)
- pickled_message = pickle.dumps(golden_message)
-
- unpickled_message = pickle.loads(pickled_message)
- self.assertEqual(unpickled_message, golden_message)
- self.assertEqual(unpickled_message.a, 1)
- # This is still an incomplete proto - so serializing should fail
- self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
-
- def testPositiveInfinity(self):
- golden_data = (b'\x5D\x00\x00\x80\x7F'
- b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
- b'\xCD\x02\x00\x00\x80\x7F'
- b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
- golden_message = unittest_pb2.TestAllTypes()
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsPosInf(golden_message.optional_float))
self.assertTrue(IsPosInf(golden_message.optional_double))
@@ -158,12 +145,19 @@ class MessageTest(unittest.TestCase):
self.assertTrue(IsPosInf(golden_message.repeated_double[0]))
self.assertEqual(golden_data, golden_message.SerializeToString())
- def testNegativeInfinity(self):
- golden_data = (b'\x5D\x00\x00\x80\xFF'
- b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
- b'\xCD\x02\x00\x00\x80\xFF'
- b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
- golden_message = unittest_pb2.TestAllTypes()
+ def testNegativeInfinity(self, message_module):
+ if message_module is unittest_pb2:
+ golden_data = (b'\x5D\x00\x00\x80\xFF'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
+ b'\xCD\x02\x00\x00\x80\xFF'
+ b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
+ else:
+ golden_data = (b'\x5D\x00\x00\x80\xFF'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
+ b'\xCA\x02\x04\x00\x00\x80\xFF'
+ b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
+
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsNegInf(golden_message.optional_float))
self.assertTrue(IsNegInf(golden_message.optional_double))
@@ -171,12 +165,12 @@ class MessageTest(unittest.TestCase):
self.assertTrue(IsNegInf(golden_message.repeated_double[0]))
self.assertEqual(golden_data, golden_message.SerializeToString())
- def testNotANumber(self):
+ def testNotANumber(self, message_module):
golden_data = (b'\x5D\x00\x00\xC0\x7F'
b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
b'\xCD\x02\x00\x00\xC0\x7F'
b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
- golden_message = unittest_pb2.TestAllTypes()
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(isnan(golden_message.optional_float))
self.assertTrue(isnan(golden_message.optional_double))
@@ -188,47 +182,47 @@ class MessageTest(unittest.TestCase):
# verify the serialized string can be converted into a correctly
# behaving protocol buffer.
serialized = golden_message.SerializeToString()
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.ParseFromString(serialized)
self.assertTrue(isnan(message.optional_float))
self.assertTrue(isnan(message.optional_double))
self.assertTrue(isnan(message.repeated_float[0]))
self.assertTrue(isnan(message.repeated_double[0]))
- def testPositiveInfinityPacked(self):
+ def testPositiveInfinityPacked(self, message_module):
golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
- golden_message = unittest_pb2.TestPackedTypes()
+ golden_message = message_module.TestPackedTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsPosInf(golden_message.packed_float[0]))
self.assertTrue(IsPosInf(golden_message.packed_double[0]))
self.assertEqual(golden_data, golden_message.SerializeToString())
- def testNegativeInfinityPacked(self):
+ def testNegativeInfinityPacked(self, message_module):
golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
- golden_message = unittest_pb2.TestPackedTypes()
+ golden_message = message_module.TestPackedTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsNegInf(golden_message.packed_float[0]))
self.assertTrue(IsNegInf(golden_message.packed_double[0]))
self.assertEqual(golden_data, golden_message.SerializeToString())
- def testNotANumberPacked(self):
+ def testNotANumberPacked(self, message_module):
golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
- golden_message = unittest_pb2.TestPackedTypes()
+ golden_message = message_module.TestPackedTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(isnan(golden_message.packed_float[0]))
self.assertTrue(isnan(golden_message.packed_double[0]))
serialized = golden_message.SerializeToString()
- message = unittest_pb2.TestPackedTypes()
+ message = message_module.TestPackedTypes()
message.ParseFromString(serialized)
self.assertTrue(isnan(message.packed_float[0]))
self.assertTrue(isnan(message.packed_double[0]))
- def testExtremeFloatValues(self):
- message = unittest_pb2.TestAllTypes()
+ def testExtremeFloatValues(self, message_module):
+ message = message_module.TestAllTypes()
# Most positive exponent, no significand bits set.
kMostPosExponentNoSigBits = math.pow(2, 127)
@@ -272,8 +266,8 @@ class MessageTest(unittest.TestCase):
message.ParseFromString(message.SerializeToString())
self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit)
- def testExtremeDoubleValues(self):
- message = unittest_pb2.TestAllTypes()
+ def testExtremeDoubleValues(self, message_module):
+ message = message_module.TestAllTypes()
# Most positive exponent, no significand bits set.
kMostPosExponentNoSigBits = math.pow(2, 1023)
@@ -317,43 +311,43 @@ class MessageTest(unittest.TestCase):
message.ParseFromString(message.SerializeToString())
self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit)
- def testFloatPrinting(self):
- message = unittest_pb2.TestAllTypes()
+ def testFloatPrinting(self, message_module):
+ message = message_module.TestAllTypes()
message.optional_float = 2.0
self.assertEqual(str(message), 'optional_float: 2.0\n')
- def testHighPrecisionFloatPrinting(self):
- message = unittest_pb2.TestAllTypes()
+ def testHighPrecisionFloatPrinting(self, message_module):
+ message = message_module.TestAllTypes()
message.optional_double = 0.12345678912345678
if sys.version_info.major >= 3:
self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n')
else:
self.assertEqual(str(message), 'optional_double: 0.123456789123\n')
- def testUnknownFieldPrinting(self):
- populated = unittest_pb2.TestAllTypes()
+ def testUnknownFieldPrinting(self, message_module):
+ populated = message_module.TestAllTypes()
test_util.SetAllNonLazyFields(populated)
- empty = unittest_pb2.TestEmptyMessage()
+ empty = message_module.TestEmptyMessage()
empty.ParseFromString(populated.SerializeToString())
self.assertEqual(str(empty), '')
- def testRepeatedNestedFieldIteration(self):
- msg = unittest_pb2.TestAllTypes()
+ def testRepeatedNestedFieldIteration(self, message_module):
+ msg = message_module.TestAllTypes()
msg.repeated_nested_message.add(bb=1)
msg.repeated_nested_message.add(bb=2)
msg.repeated_nested_message.add(bb=3)
msg.repeated_nested_message.add(bb=4)
self.assertEqual([1, 2, 3, 4],
- [m.bb for m in msg.repeated_nested_message])
+ [m.bb for m in msg.repeated_nested_message])
self.assertEqual([4, 3, 2, 1],
- [m.bb for m in reversed(msg.repeated_nested_message)])
+ [m.bb for m in reversed(msg.repeated_nested_message)])
self.assertEqual([4, 3, 2, 1],
- [m.bb for m in msg.repeated_nested_message[::-1]])
+ [m.bb for m in msg.repeated_nested_message[::-1]])
- def testSortingRepeatedScalarFieldsDefaultComparator(self):
+ def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
"""Check some different types with the default comparator."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
# TODO(mattp): would testing more scalar types strengthen test?
message.repeated_int32.append(1)
@@ -388,9 +382,9 @@ class MessageTest(unittest.TestCase):
self.assertEqual(message.repeated_bytes[1], b'b')
self.assertEqual(message.repeated_bytes[2], b'c')
- def testSortingRepeatedScalarFieldsCustomComparator(self):
+ def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
"""Check some different types with custom comparator."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.repeated_int32.append(-3)
message.repeated_int32.append(-2)
@@ -408,9 +402,9 @@ class MessageTest(unittest.TestCase):
self.assertEqual(message.repeated_string[1], 'bb')
self.assertEqual(message.repeated_string[2], 'aaa')
- def testSortingRepeatedCompositeFieldsCustomComparator(self):
+ def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module):
"""Check passing a custom comparator to sort a repeated composite field."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.repeated_nested_message.add().bb = 1
message.repeated_nested_message.add().bb = 3
@@ -426,9 +420,9 @@ class MessageTest(unittest.TestCase):
self.assertEqual(message.repeated_nested_message[4].bb, 5)
self.assertEqual(message.repeated_nested_message[5].bb, 6)
- def testRepeatedCompositeFieldSortArguments(self):
+ def testRepeatedCompositeFieldSortArguments(self, message_module):
"""Check sorting a repeated composite field using list.sort() arguments."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
get_bb = operator.attrgetter('bb')
cmp_bb = lambda a, b: cmp(a.bb, b.bb)
@@ -452,9 +446,9 @@ class MessageTest(unittest.TestCase):
self.assertEqual([k.bb for k in message.repeated_nested_message],
[6, 5, 4, 3, 2, 1])
- def testRepeatedScalarFieldSortArguments(self):
+ def testRepeatedScalarFieldSortArguments(self, message_module):
"""Check sorting a scalar field using list.sort() arguments."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.repeated_int32.append(-3)
message.repeated_int32.append(-2)
@@ -484,9 +478,9 @@ class MessageTest(unittest.TestCase):
message.repeated_string.sort(cmp=len_cmp, reverse=True)
self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
- def testRepeatedFieldsComparable(self):
- m1 = unittest_pb2.TestAllTypes()
- m2 = unittest_pb2.TestAllTypes()
+ def testRepeatedFieldsComparable(self, message_module):
+ m1 = message_module.TestAllTypes()
+ m2 = message_module.TestAllTypes()
m1.repeated_int32.append(0)
m1.repeated_int32.append(1)
m1.repeated_int32.append(2)
@@ -519,55 +513,6 @@ class MessageTest(unittest.TestCase):
# TODO(anuraag): Implement extensiondict comparison in C++ and then add test
- def testParsingMerge(self):
- """Check the merge behavior when a required or optional field appears
- multiple times in the input."""
- messages = [
- unittest_pb2.TestAllTypes(),
- unittest_pb2.TestAllTypes(),
- unittest_pb2.TestAllTypes() ]
- messages[0].optional_int32 = 1
- messages[1].optional_int64 = 2
- messages[2].optional_int32 = 3
- messages[2].optional_string = 'hello'
-
- merged_message = unittest_pb2.TestAllTypes()
- merged_message.optional_int32 = 3
- merged_message.optional_int64 = 2
- merged_message.optional_string = 'hello'
-
- generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
- generator.field1.extend(messages)
- generator.field2.extend(messages)
- generator.field3.extend(messages)
- generator.ext1.extend(messages)
- generator.ext2.extend(messages)
- generator.group1.add().field1.MergeFrom(messages[0])
- generator.group1.add().field1.MergeFrom(messages[1])
- generator.group1.add().field1.MergeFrom(messages[2])
- generator.group2.add().field1.MergeFrom(messages[0])
- generator.group2.add().field1.MergeFrom(messages[1])
- generator.group2.add().field1.MergeFrom(messages[2])
-
- data = generator.SerializeToString()
- parsing_merge = unittest_pb2.TestParsingMerge()
- parsing_merge.ParseFromString(data)
-
- # Required and optional fields should be merged.
- self.assertEqual(parsing_merge.required_all_types, merged_message)
- self.assertEqual(parsing_merge.optional_all_types, merged_message)
- self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
- merged_message)
- self.assertEqual(parsing_merge.Extensions[
- unittest_pb2.TestParsingMerge.optional_ext],
- merged_message)
-
- # Repeated fields should not be merged.
- self.assertEqual(len(parsing_merge.repeated_all_types), 3)
- self.assertEqual(len(parsing_merge.repeatedgroup), 3)
- self.assertEqual(len(parsing_merge.Extensions[
- unittest_pb2.TestParsingMerge.repeated_ext]), 3)
-
def ensureNestedMessageExists(self, msg, attribute):
"""Make sure that a nested message object exists.
@@ -577,12 +522,28 @@ class MessageTest(unittest.TestCase):
getattr(msg, attribute)
self.assertFalse(msg.HasField(attribute))
- def testOneofGetCaseNonexistingField(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofGetCaseNonexistingField(self, message_module):
+ m = message_module.TestAllTypes()
self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
- def testOneofSemantics(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofDefaultValues(self, message_module):
+ m = message_module.TestAllTypes()
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+ self.assertFalse(m.HasField('oneof_uint32'))
+
+ # Oneof is set even when setting it to a default value.
+ m.oneof_uint32 = 0
+ self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
+ self.assertTrue(m.HasField('oneof_uint32'))
+ self.assertFalse(m.HasField('oneof_string'))
+
+ m.oneof_string = ""
+ self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
+ self.assertTrue(m.HasField('oneof_string'))
+ self.assertFalse(m.HasField('oneof_uint32'))
+
+ def testOneofSemantics(self, message_module):
+ m = message_module.TestAllTypes()
self.assertIs(None, m.WhichOneof('oneof_field'))
m.oneof_uint32 = 11
@@ -604,96 +565,1024 @@ class MessageTest(unittest.TestCase):
self.assertFalse(m.HasField('oneof_nested_message'))
self.assertTrue(m.HasField('oneof_bytes'))
- def testOneofCompositeFieldReadAccess(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofCompositeFieldReadAccess(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
self.ensureNestedMessageExists(m, 'oneof_nested_message')
self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
self.assertEqual(11, m.oneof_uint32)
- def testOneofHasField(self):
- m = unittest_pb2.TestAllTypes()
- self.assertFalse(m.HasField('oneof_field'))
+ def testOneofWhichOneof(self, message_module):
+ m = message_module.TestAllTypes()
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertFalse(m.HasField('oneof_field'))
+
m.oneof_uint32 = 11
- self.assertTrue(m.HasField('oneof_field'))
+ self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertTrue(m.HasField('oneof_field'))
+
m.oneof_bytes = b'bb'
- self.assertTrue(m.HasField('oneof_field'))
+ self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
+
m.ClearField('oneof_bytes')
- self.assertFalse(m.HasField('oneof_field'))
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertFalse(m.HasField('oneof_field'))
- def testOneofClearField(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofClearField(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
m.ClearField('oneof_field')
- self.assertFalse(m.HasField('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertFalse(m.HasField('oneof_field'))
self.assertFalse(m.HasField('oneof_uint32'))
self.assertIs(None, m.WhichOneof('oneof_field'))
- def testOneofClearSetField(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofClearSetField(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
m.ClearField('oneof_uint32')
- self.assertFalse(m.HasField('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertFalse(m.HasField('oneof_field'))
self.assertFalse(m.HasField('oneof_uint32'))
self.assertIs(None, m.WhichOneof('oneof_field'))
- def testOneofClearUnsetField(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofClearUnsetField(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
self.ensureNestedMessageExists(m, 'oneof_nested_message')
m.ClearField('oneof_nested_message')
self.assertEqual(11, m.oneof_uint32)
- self.assertTrue(m.HasField('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertTrue(m.HasField('oneof_field'))
self.assertTrue(m.HasField('oneof_uint32'))
self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
- def testOneofDeserialize(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofDeserialize(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
- m2 = unittest_pb2.TestAllTypes()
+ m2 = message_module.TestAllTypes()
m2.ParseFromString(m.SerializeToString())
self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
- def testOneofCopyFrom(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofCopyFrom(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
- m2 = unittest_pb2.TestAllTypes()
+ m2 = message_module.TestAllTypes()
m2.CopyFrom(m)
self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
- def testOneofNestedMergeFrom(self):
- m = unittest_pb2.NestedTestAllTypes()
+ def testOneofNestedMergeFrom(self, message_module):
+ m = message_module.NestedTestAllTypes()
m.payload.oneof_uint32 = 11
- m2 = unittest_pb2.NestedTestAllTypes()
+ m2 = message_module.NestedTestAllTypes()
m2.payload.oneof_bytes = b'bb'
m2.child.payload.oneof_bytes = b'bb'
m2.MergeFrom(m)
self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field'))
self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field'))
- def testOneofClear(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofMessageMergeFrom(self, message_module):
+ m = message_module.NestedTestAllTypes()
+ m.payload.oneof_nested_message.bb = 11
+ m.child.payload.oneof_nested_message.bb = 12
+ m2 = message_module.NestedTestAllTypes()
+ m2.payload.oneof_uint32 = 13
+ m2.MergeFrom(m)
+ self.assertEqual('oneof_nested_message',
+ m2.payload.WhichOneof('oneof_field'))
+ self.assertEqual('oneof_nested_message',
+ m2.child.payload.WhichOneof('oneof_field'))
+
+ def testOneofNestedMessageInit(self, message_module):
+ m = message_module.TestAllTypes(
+ oneof_nested_message=message_module.TestAllTypes.NestedMessage())
+ self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
+
+ def testOneofClear(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
m.Clear()
self.assertIsNone(m.WhichOneof('oneof_field'))
m.oneof_bytes = b'bb'
- self.assertTrue(m.HasField('oneof_field'))
+ self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
+ def testAssignByteStringToUnicodeField(self, message_module):
+ """Assigning a byte string to a string field should result
+ in the value being converted to a Unicode string."""
+ m = message_module.TestAllTypes()
+ m.optional_string = str('')
+ self.assertTrue(isinstance(m.optional_string, unicode))
- def testSortEmptyRepeatedCompositeContainer(self):
+# TODO(haberman): why are these tests Google-internal only?
+
+ def testLongValuedSlice(self, message_module):
+ """It should be possible to use long-valued indicies in slices
+
+ This didn't used to work in the v2 C++ implementation.
+ """
+ m = message_module.TestAllTypes()
+
+ # Repeated scalar
+ m.repeated_int32.append(1)
+ sl = m.repeated_int32[long(0):long(len(m.repeated_int32))]
+ self.assertEqual(len(m.repeated_int32), len(sl))
+
+ # Repeated composite
+ m.repeated_nested_message.add().bb = 3
+ sl = m.repeated_nested_message[long(0):long(len(m.repeated_nested_message))]
+ self.assertEqual(len(m.repeated_nested_message), len(sl))
+
+ def testExtendShouldNotSwallowExceptions(self, message_module):
+ """This didn't use to work in the v2 C++ implementation."""
+ m = message_module.TestAllTypes()
+ with self.assertRaises(NameError) as _:
+ m.repeated_int32.extend(a for i in range(10)) # pylint: disable=undefined-variable
+ with self.assertRaises(NameError) as _:
+ m.repeated_nested_enum.extend(
+ a for i in range(10)) # pylint: disable=undefined-variable
+
+ FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()]
+
+ def testExtendInt32WithNothing(self, message_module):
+ """Test no-ops extending repeated int32 fields."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_int32)
+
+ # TODO(ptucker): Deprecate this behavior. b/18413862
+ for falsy_value in MessageTest.FALSY_VALUES:
+ m.repeated_int32.extend(falsy_value)
+ self.assertSequenceEqual([], m.repeated_int32)
+
+ m.repeated_int32.extend([])
+ self.assertSequenceEqual([], m.repeated_int32)
+
+ def testExtendFloatWithNothing(self, message_module):
+ """Test no-ops extending repeated float fields."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_float)
+
+ # TODO(ptucker): Deprecate this behavior. b/18413862
+ for falsy_value in MessageTest.FALSY_VALUES:
+ m.repeated_float.extend(falsy_value)
+ self.assertSequenceEqual([], m.repeated_float)
+
+ m.repeated_float.extend([])
+ self.assertSequenceEqual([], m.repeated_float)
+
+ def testExtendStringWithNothing(self, message_module):
+ """Test no-ops extending repeated string fields."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_string)
+
+ # TODO(ptucker): Deprecate this behavior. b/18413862
+ for falsy_value in MessageTest.FALSY_VALUES:
+ m.repeated_string.extend(falsy_value)
+ self.assertSequenceEqual([], m.repeated_string)
+
+ m.repeated_string.extend([])
+ self.assertSequenceEqual([], m.repeated_string)
+
+ def testExtendInt32WithPythonList(self, message_module):
+ """Test extending repeated int32 fields with python lists."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_int32)
+ m.repeated_int32.extend([0])
+ self.assertSequenceEqual([0], m.repeated_int32)
+ m.repeated_int32.extend([1, 2])
+ self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
+ m.repeated_int32.extend([3, 4])
+ self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
+
+ def testExtendFloatWithPythonList(self, message_module):
+ """Test extending repeated float fields with python lists."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_float)
+ m.repeated_float.extend([0.0])
+ self.assertSequenceEqual([0.0], m.repeated_float)
+ m.repeated_float.extend([1.0, 2.0])
+ self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
+ m.repeated_float.extend([3.0, 4.0])
+ self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
+
+ def testExtendStringWithPythonList(self, message_module):
+ """Test extending repeated string fields with python lists."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_string)
+ m.repeated_string.extend([''])
+ self.assertSequenceEqual([''], m.repeated_string)
+ m.repeated_string.extend(['11', '22'])
+ self.assertSequenceEqual(['', '11', '22'], m.repeated_string)
+ m.repeated_string.extend(['33', '44'])
+ self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string)
+
+ def testExtendStringWithString(self, message_module):
+ """Test extending repeated string fields with characters from a string."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_string)
+ m.repeated_string.extend('abc')
+ self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string)
+
+ class TestIterable(object):
+ """This iterable object mimics the behavior of numpy.array.
+
+ __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1.
+
+ """
+
+ def __init__(self, values=None):
+ self._list = values or []
+
+ def __nonzero__(self):
+ size = len(self._list)
+ if size == 0:
+ return False
+ if size == 1:
+ return bool(self._list[0])
+ raise ValueError('Truth value is ambiguous.')
+
+ def __len__(self):
+ return len(self._list)
+
+ def __iter__(self):
+ return self._list.__iter__()
+
+ def testExtendInt32WithIterable(self, message_module):
+ """Test extending repeated int32 fields with iterable."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_int32)
+ m.repeated_int32.extend(MessageTest.TestIterable([]))
+ self.assertSequenceEqual([], m.repeated_int32)
+ m.repeated_int32.extend(MessageTest.TestIterable([0]))
+ self.assertSequenceEqual([0], m.repeated_int32)
+ m.repeated_int32.extend(MessageTest.TestIterable([1, 2]))
+ self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
+ m.repeated_int32.extend(MessageTest.TestIterable([3, 4]))
+ self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
+
+ def testExtendFloatWithIterable(self, message_module):
+ """Test extending repeated float fields with iterable."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_float)
+ m.repeated_float.extend(MessageTest.TestIterable([]))
+ self.assertSequenceEqual([], m.repeated_float)
+ m.repeated_float.extend(MessageTest.TestIterable([0.0]))
+ self.assertSequenceEqual([0.0], m.repeated_float)
+ m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0]))
+ self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
+ m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0]))
+ self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
+
+ def testExtendStringWithIterable(self, message_module):
+ """Test extending repeated string fields with iterable."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_string)
+ m.repeated_string.extend(MessageTest.TestIterable([]))
+ self.assertSequenceEqual([], m.repeated_string)
+ m.repeated_string.extend(MessageTest.TestIterable(['']))
+ self.assertSequenceEqual([''], m.repeated_string)
+ m.repeated_string.extend(MessageTest.TestIterable(['1', '2']))
+ self.assertSequenceEqual(['', '1', '2'], m.repeated_string)
+ m.repeated_string.extend(MessageTest.TestIterable(['3', '4']))
+ self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string)
+
+ def testPickleRepeatedScalarContainer(self, message_module):
+ # TODO(tibell): The pure-Python implementation support pickling of
+ # scalar containers in *some* cases. For now the cpp2 version
+ # throws an exception to avoid a segfault. Investigate if we
+ # want to support pickling of these fields.
+ #
+ # For more information see: https://b2.corp.google.com/u/0/issues/18677897
+ if (api_implementation.Type() != 'cpp' or
+ api_implementation.Version() == 2):
+ return
+ m = message_module.TestAllTypes()
+ with self.assertRaises(pickle.PickleError) as _:
+ pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL)
+
+
+ def testSortEmptyRepeatedCompositeContainer(self, message_module):
"""Exercise a scenario that has led to segfaults in the past.
"""
- m = unittest_pb2.TestAllTypes()
+ m = message_module.TestAllTypes()
m.repeated_nested_message.sort()
- def testHasFieldOnRepeatedField(self):
+ def testHasFieldOnRepeatedField(self, message_module):
"""Using HasField on a repeated field should raise an exception.
"""
- m = unittest_pb2.TestAllTypes()
+ m = message_module.TestAllTypes()
with self.assertRaises(ValueError) as _:
m.HasField('repeated_int32')
+ def testRepeatedScalarFieldPop(self, message_module):
+ m = message_module.TestAllTypes()
+ with self.assertRaises(IndexError) as _:
+ m.repeated_int32.pop()
+ m.repeated_int32.extend(range(5))
+ self.assertEqual(4, m.repeated_int32.pop())
+ self.assertEqual(0, m.repeated_int32.pop(0))
+ self.assertEqual(2, m.repeated_int32.pop(1))
+ self.assertEqual([1, 3], m.repeated_int32)
+
+ def testRepeatedCompositeFieldPop(self, message_module):
+ m = message_module.TestAllTypes()
+ with self.assertRaises(IndexError) as _:
+ m.repeated_nested_message.pop()
+ for i in range(5):
+ n = m.repeated_nested_message.add()
+ n.bb = i
+ self.assertEqual(4, m.repeated_nested_message.pop().bb)
+ self.assertEqual(0, m.repeated_nested_message.pop(0).bb)
+ self.assertEqual(2, m.repeated_nested_message.pop(1).bb)
+ self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message])
+
+
+# Class to test proto2-only features (required, extensions, etc.)
+class Proto2Test(unittest.TestCase):
+
+ def testFieldPresence(self):
+ message = unittest_pb2.TestAllTypes()
+
+ self.assertFalse(message.HasField("optional_int32"))
+ self.assertFalse(message.HasField("optional_bool"))
+ self.assertFalse(message.HasField("optional_nested_message"))
+
+ with self.assertRaises(ValueError):
+ message.HasField("field_doesnt_exist")
+
+ with self.assertRaises(ValueError):
+ message.HasField("repeated_int32")
+ with self.assertRaises(ValueError):
+ message.HasField("repeated_nested_message")
+
+ self.assertEqual(0, message.optional_int32)
+ self.assertEqual(False, message.optional_bool)
+ self.assertEqual(0, message.optional_nested_message.bb)
+
+ # Fields are set even when setting the values to default values.
+ message.optional_int32 = 0
+ message.optional_bool = False
+ message.optional_nested_message.bb = 0
+ self.assertTrue(message.HasField("optional_int32"))
+ self.assertTrue(message.HasField("optional_bool"))
+ self.assertTrue(message.HasField("optional_nested_message"))
+
+ # Set the fields to non-default values.
+ message.optional_int32 = 5
+ message.optional_bool = True
+ message.optional_nested_message.bb = 15
+
+ self.assertTrue(message.HasField("optional_int32"))
+ self.assertTrue(message.HasField("optional_bool"))
+ self.assertTrue(message.HasField("optional_nested_message"))
+
+ # Clearing the fields unsets them and resets their value to default.
+ message.ClearField("optional_int32")
+ message.ClearField("optional_bool")
+ message.ClearField("optional_nested_message")
+
+ self.assertFalse(message.HasField("optional_int32"))
+ self.assertFalse(message.HasField("optional_bool"))
+ self.assertFalse(message.HasField("optional_nested_message"))
+ self.assertEqual(0, message.optional_int32)
+ self.assertEqual(False, message.optional_bool)
+ self.assertEqual(0, message.optional_nested_message.bb)
+
+ # TODO(tibell): The C++ implementations actually allows assignment
+ # of unknown enum values to *scalar* fields (but not repeated
+ # fields). Once checked enum fields becomes the default in the
+ # Python implementation, the C++ implementation should follow suit.
+ def testAssignInvalidEnum(self):
+ """It should not be possible to assign an invalid enum number to an
+ enum field."""
+ m = unittest_pb2.TestAllTypes()
+
+ with self.assertRaises(ValueError) as _:
+ m.optional_nested_enum = 1234567
+ self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567)
+
+ def testGoldenExtensions(self):
+ golden_data = test_util.GoldenFileData('golden_message')
+ golden_message = unittest_pb2.TestAllExtensions()
+ golden_message.ParseFromString(golden_data)
+ all_set = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(all_set)
+ self.assertEqual(all_set, golden_message)
+ self.assertEqual(golden_data, golden_message.SerializeToString())
+ golden_copy = copy.deepcopy(golden_message)
+ self.assertEqual(golden_data, golden_copy.SerializeToString())
+
+ def testGoldenPackedExtensions(self):
+ golden_data = test_util.GoldenFileData('golden_packed_fields_message')
+ golden_message = unittest_pb2.TestPackedExtensions()
+ golden_message.ParseFromString(golden_data)
+ all_set = unittest_pb2.TestPackedExtensions()
+ test_util.SetAllPackedExtensions(all_set)
+ self.assertEqual(all_set, golden_message)
+ self.assertEqual(golden_data, all_set.SerializeToString())
+ golden_copy = copy.deepcopy(golden_message)
+ self.assertEqual(golden_data, golden_copy.SerializeToString())
+
+ def testPickleIncompleteProto(self):
+ golden_message = unittest_pb2.TestRequired(a=1)
+ pickled_message = pickle.dumps(golden_message)
+
+ unpickled_message = pickle.loads(pickled_message)
+ self.assertEqual(unpickled_message, golden_message)
+ self.assertEqual(unpickled_message.a, 1)
+ # This is still an incomplete proto - so serializing should fail
+ self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
+
+
+ # TODO(haberman): this isn't really a proto2-specific test except that this
+ # message has a required field in it. Should probably be factored out so
+ # that we can test the other parts with proto3.
+ def testParsingMerge(self):
+ """Check the merge behavior when a required or optional field appears
+ multiple times in the input."""
+ messages = [
+ unittest_pb2.TestAllTypes(),
+ unittest_pb2.TestAllTypes(),
+ unittest_pb2.TestAllTypes() ]
+ messages[0].optional_int32 = 1
+ messages[1].optional_int64 = 2
+ messages[2].optional_int32 = 3
+ messages[2].optional_string = 'hello'
+
+ merged_message = unittest_pb2.TestAllTypes()
+ merged_message.optional_int32 = 3
+ merged_message.optional_int64 = 2
+ merged_message.optional_string = 'hello'
+
+ generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
+ generator.field1.extend(messages)
+ generator.field2.extend(messages)
+ generator.field3.extend(messages)
+ generator.ext1.extend(messages)
+ generator.ext2.extend(messages)
+ generator.group1.add().field1.MergeFrom(messages[0])
+ generator.group1.add().field1.MergeFrom(messages[1])
+ generator.group1.add().field1.MergeFrom(messages[2])
+ generator.group2.add().field1.MergeFrom(messages[0])
+ generator.group2.add().field1.MergeFrom(messages[1])
+ generator.group2.add().field1.MergeFrom(messages[2])
+
+ data = generator.SerializeToString()
+ parsing_merge = unittest_pb2.TestParsingMerge()
+ parsing_merge.ParseFromString(data)
+
+ # Required and optional fields should be merged.
+ self.assertEqual(parsing_merge.required_all_types, merged_message)
+ self.assertEqual(parsing_merge.optional_all_types, merged_message)
+ self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
+ merged_message)
+ self.assertEqual(parsing_merge.Extensions[
+ unittest_pb2.TestParsingMerge.optional_ext],
+ merged_message)
+
+ # Repeated fields should not be merged.
+ self.assertEqual(len(parsing_merge.repeated_all_types), 3)
+ self.assertEqual(len(parsing_merge.repeatedgroup), 3)
+ self.assertEqual(len(parsing_merge.Extensions[
+ unittest_pb2.TestParsingMerge.repeated_ext]), 3)
+
+ def testPythonicInit(self):
+ message = unittest_pb2.TestAllTypes(
+ optional_int32=100,
+ optional_fixed32=200,
+ optional_float=300.5,
+ optional_bytes=b'x',
+ optionalgroup={'a': 400},
+ optional_nested_message={'bb': 500},
+ optional_nested_enum='BAZ',
+ repeatedgroup=[{'a': 600},
+ {'a': 700}],
+ repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
+ default_int32=800,
+ oneof_string='y')
+ self.assertTrue(isinstance(message, unittest_pb2.TestAllTypes))
+ self.assertEqual(100, message.optional_int32)
+ self.assertEqual(200, message.optional_fixed32)
+ self.assertEqual(300.5, message.optional_float)
+ self.assertEqual(b'x', message.optional_bytes)
+ self.assertEqual(400, message.optionalgroup.a)
+ self.assertTrue(isinstance(message.optional_nested_message,
+ unittest_pb2.TestAllTypes.NestedMessage))
+ self.assertEqual(500, message.optional_nested_message.bb)
+ self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
+ message.optional_nested_enum)
+ self.assertEqual(2, len(message.repeatedgroup))
+ self.assertEqual(600, message.repeatedgroup[0].a)
+ self.assertEqual(700, message.repeatedgroup[1].a)
+ self.assertEqual(2, len(message.repeated_nested_enum))
+ self.assertEqual(unittest_pb2.TestAllTypes.FOO,
+ message.repeated_nested_enum[0])
+ self.assertEqual(unittest_pb2.TestAllTypes.BAR,
+ message.repeated_nested_enum[1])
+ self.assertEqual(800, message.default_int32)
+ self.assertEqual('y', message.oneof_string)
+ self.assertFalse(message.HasField('optional_int64'))
+ self.assertEqual(0, len(message.repeated_float))
+ self.assertEqual(42, message.default_int64)
+
+ message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ')
+ self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
+ message.optional_nested_enum)
+
+ with self.assertRaises(ValueError):
+ unittest_pb2.TestAllTypes(
+ optional_nested_message={'INVALID_NESTED_FIELD': 17})
+
+ with self.assertRaises(TypeError):
+ unittest_pb2.TestAllTypes(
+ optional_nested_message={'bb': 'INVALID_VALUE_TYPE'})
+
+ with self.assertRaises(ValueError):
+ unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL')
+
+ with self.assertRaises(ValueError):
+ unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')
+
+
+# Class to test proto3-only features/behavior (updated field presence & enums)
+class Proto3Test(unittest.TestCase):
+
+ # Utility method for comparing equality with a map.
+ def assertMapIterEquals(self, map_iter, dict_value):
+ # Avoid mutating caller's copy.
+ dict_value = dict(dict_value)
+
+ for k, v in map_iter:
+ self.assertEqual(v, dict_value[k])
+ del dict_value[k]
+
+ self.assertEqual({}, dict_value)
+
+ def testFieldPresence(self):
+ message = unittest_proto3_arena_pb2.TestAllTypes()
+
+ # We can't test presence of non-repeated, non-submessage fields.
+ with self.assertRaises(ValueError):
+ message.HasField('optional_int32')
+ with self.assertRaises(ValueError):
+ message.HasField('optional_float')
+ with self.assertRaises(ValueError):
+ message.HasField('optional_string')
+ with self.assertRaises(ValueError):
+ message.HasField('optional_bool')
+
+ # But we can still test presence of submessage fields.
+ self.assertFalse(message.HasField('optional_nested_message'))
+
+ # As with proto2, we can't test presence of fields that don't exist, or
+ # repeated fields.
+ with self.assertRaises(ValueError):
+ message.HasField('field_doesnt_exist')
+
+ with self.assertRaises(ValueError):
+ message.HasField('repeated_int32')
+ with self.assertRaises(ValueError):
+ message.HasField('repeated_nested_message')
+
+ # Fields should default to their type-specific default.
+ self.assertEqual(0, message.optional_int32)
+ self.assertEqual(0, message.optional_float)
+ self.assertEqual('', message.optional_string)
+ self.assertEqual(False, message.optional_bool)
+ self.assertEqual(0, message.optional_nested_message.bb)
+
+ # Setting a submessage should still return proper presence information.
+ message.optional_nested_message.bb = 0
+ self.assertTrue(message.HasField('optional_nested_message'))
+
+ # Set the fields to non-default values.
+ message.optional_int32 = 5
+ message.optional_float = 1.1
+ message.optional_string = 'abc'
+ message.optional_bool = True
+ message.optional_nested_message.bb = 15
+
+ # Clearing the fields unsets them and resets their value to default.
+ message.ClearField('optional_int32')
+ message.ClearField('optional_float')
+ message.ClearField('optional_string')
+ message.ClearField('optional_bool')
+ message.ClearField('optional_nested_message')
+
+ self.assertEqual(0, message.optional_int32)
+ self.assertEqual(0, message.optional_float)
+ self.assertEqual('', message.optional_string)
+ self.assertEqual(False, message.optional_bool)
+ self.assertEqual(0, message.optional_nested_message.bb)
+
+ def testAssignUnknownEnum(self):
+ """Assigning an unknown enum value is allowed and preserves the value."""
+ m = unittest_proto3_arena_pb2.TestAllTypes()
+
+ m.optional_nested_enum = 1234567
+ self.assertEqual(1234567, m.optional_nested_enum)
+ m.repeated_nested_enum.append(22334455)
+ self.assertEqual(22334455, m.repeated_nested_enum[0])
+ # Assignment is a different code path than append for the C++ impl.
+ m.repeated_nested_enum[0] = 7654321
+ self.assertEqual(7654321, m.repeated_nested_enum[0])
+ serialized = m.SerializeToString()
+
+ m2 = unittest_proto3_arena_pb2.TestAllTypes()
+ m2.ParseFromString(serialized)
+ self.assertEqual(1234567, m2.optional_nested_enum)
+ self.assertEqual(7654321, m2.repeated_nested_enum[0])
+
+ # Map isn't really a proto3-only feature. But there is no proto2 equivalent
+ # of google/protobuf/map_unittest.proto right now, so it's not easy to
+ # test both with the same test like we do for the other proto2/proto3 tests.
+ # (google/protobuf/map_protobuf_unittest.proto is very different in the set
+ # of messages and fields it contains).
+ def testScalarMapDefaults(self):
+ msg = map_unittest_pb2.TestMap()
+
+ # Scalars start out unset.
+ self.assertFalse(-123 in msg.map_int32_int32)
+ self.assertFalse(-2**33 in msg.map_int64_int64)
+ self.assertFalse(123 in msg.map_uint32_uint32)
+ self.assertFalse(2**33 in msg.map_uint64_uint64)
+ self.assertFalse('abc' in msg.map_string_string)
+ self.assertFalse(888 in msg.map_int32_enum)
+
+ # Accessing an unset key returns the default.
+ self.assertEqual(0, msg.map_int32_int32[-123])
+ self.assertEqual(0, msg.map_int64_int64[-2**33])
+ self.assertEqual(0, msg.map_uint32_uint32[123])
+ self.assertEqual(0, msg.map_uint64_uint64[2**33])
+ self.assertEqual('', msg.map_string_string['abc'])
+ self.assertEqual(0, msg.map_int32_enum[888])
+
+ # It also sets the value in the map
+ self.assertTrue(-123 in msg.map_int32_int32)
+ self.assertTrue(-2**33 in msg.map_int64_int64)
+ self.assertTrue(123 in msg.map_uint32_uint32)
+ self.assertTrue(2**33 in msg.map_uint64_uint64)
+ self.assertTrue('abc' in msg.map_string_string)
+ self.assertTrue(888 in msg.map_int32_enum)
+
+ self.assertTrue(isinstance(msg.map_string_string['abc'], unicode))
+
+ # Accessing an unset key still throws TypeError of the type of the key
+ # is incorrect.
+ with self.assertRaises(TypeError):
+ msg.map_string_string[123]
+
+ self.assertFalse(123 in msg.map_string_string)
+
+ def testMapGet(self):
+ # Need to test that get() properly returns the default, even though the dict
+ # has defaultdict-like semantics.
+ msg = map_unittest_pb2.TestMap()
+
+ self.assertIsNone(msg.map_int32_int32.get(5))
+ self.assertEquals(10, msg.map_int32_int32.get(5, 10))
+ self.assertIsNone(msg.map_int32_int32.get(5))
+
+ msg.map_int32_int32[5] = 15
+ self.assertEquals(15, msg.map_int32_int32.get(5))
+
+ self.assertIsNone(msg.map_int32_foreign_message.get(5))
+ self.assertEquals(10, msg.map_int32_foreign_message.get(5, 10))
+
+ submsg = msg.map_int32_foreign_message[5]
+ self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
+
+ def testScalarMap(self):
+ msg = map_unittest_pb2.TestMap()
+
+ self.assertEqual(0, len(msg.map_int32_int32))
+ self.assertFalse(5 in msg.map_int32_int32)
+
+ msg.map_int32_int32[-123] = -456
+ msg.map_int64_int64[-2**33] = -2**34
+ msg.map_uint32_uint32[123] = 456
+ msg.map_uint64_uint64[2**33] = 2**34
+ msg.map_string_string['abc'] = '123'
+ msg.map_int32_enum[888] = 2
+
+ self.assertEqual([], msg.FindInitializationErrors())
+
+ self.assertEqual(1, len(msg.map_string_string))
+
+ # Bad key.
+ with self.assertRaises(TypeError):
+ msg.map_string_string[123] = '123'
+
+ # Verify that trying to assign a bad key doesn't actually add a member to
+ # the map.
+ self.assertEqual(1, len(msg.map_string_string))
+
+ # Bad value.
+ with self.assertRaises(TypeError):
+ msg.map_string_string['123'] = 123
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(serialized)
+
+ # Bad key.
+ with self.assertRaises(TypeError):
+ msg2.map_string_string[123] = '123'
+
+ # Bad value.
+ with self.assertRaises(TypeError):
+ msg2.map_string_string['123'] = 123
+
+ self.assertEqual(-456, msg2.map_int32_int32[-123])
+ self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
+ self.assertEqual(456, msg2.map_uint32_uint32[123])
+ self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
+ self.assertEqual('123', msg2.map_string_string['abc'])
+ self.assertEqual(2, msg2.map_int32_enum[888])
+
+ def testStringUnicodeConversionInMap(self):
+ msg = map_unittest_pb2.TestMap()
+
+ unicode_obj = u'\u1234'
+ bytes_obj = unicode_obj.encode('utf8')
+
+ msg.map_string_string[bytes_obj] = bytes_obj
+
+ (key, value) = msg.map_string_string.items()[0]
+
+ self.assertEqual(key, unicode_obj)
+ self.assertEqual(value, unicode_obj)
+
+ self.assertTrue(isinstance(key, unicode))
+ self.assertTrue(isinstance(value, unicode))
+
+ def testMessageMap(self):
+ msg = map_unittest_pb2.TestMap()
+
+ self.assertEqual(0, len(msg.map_int32_foreign_message))
+ self.assertFalse(5 in msg.map_int32_foreign_message)
+
+ msg.map_int32_foreign_message[123]
+ # get_or_create() is an alias for getitem.
+ msg.map_int32_foreign_message.get_or_create(-456)
+
+ self.assertEqual(2, len(msg.map_int32_foreign_message))
+ self.assertIn(123, msg.map_int32_foreign_message)
+ self.assertIn(-456, msg.map_int32_foreign_message)
+ self.assertEqual(2, len(msg.map_int32_foreign_message))
+
+ # Bad key.
+ with self.assertRaises(TypeError):
+ msg.map_int32_foreign_message['123']
+
+ # Can't assign directly to submessage.
+ with self.assertRaises(ValueError):
+ msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123]
+
+ # Verify that trying to assign a bad key doesn't actually add a member to
+ # the map.
+ self.assertEqual(2, len(msg.map_int32_foreign_message))
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(2, len(msg2.map_int32_foreign_message))
+ self.assertIn(123, msg2.map_int32_foreign_message)
+ self.assertIn(-456, msg2.map_int32_foreign_message)
+ self.assertEqual(2, len(msg2.map_int32_foreign_message))
+
+ def testMergeFrom(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_int32[12] = 34
+ msg.map_int32_int32[56] = 78
+ msg.map_int64_int64[22] = 33
+ msg.map_int32_foreign_message[111].c = 5
+ msg.map_int32_foreign_message[222].c = 10
+
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.map_int32_int32[12] = 55
+ msg2.map_int64_int64[88] = 99
+ msg2.map_int32_foreign_message[222].c = 15
+
+ msg2.MergeFrom(msg)
+
+ self.assertEqual(34, msg2.map_int32_int32[12])
+ self.assertEqual(78, msg2.map_int32_int32[56])
+ self.assertEqual(33, msg2.map_int64_int64[22])
+ self.assertEqual(99, msg2.map_int64_int64[88])
+ self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
+ self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
+
+ # Verify that there is only one entry per key, even though the MergeFrom
+ # may have internally created multiple entries for a single key in the
+ # list representation.
+ as_dict = {}
+ for key in msg2.map_int32_foreign_message:
+ self.assertFalse(key in as_dict)
+ as_dict[key] = msg2.map_int32_foreign_message[key].c
+
+ self.assertEqual({111: 5, 222: 10}, as_dict)
+
+ # Special case: test that delete of item really removes the item, even if
+ # there might have physically been duplicate keys due to the previous merge.
+ # This is only a special case for the C++ implementation which stores the
+ # map as an array.
+ del msg2.map_int32_int32[12]
+ self.assertFalse(12 in msg2.map_int32_int32)
+
+ del msg2.map_int32_foreign_message[222]
+ self.assertFalse(222 in msg2.map_int32_foreign_message)
+
+ def testIntegerMapWithLongs(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_int32[long(-123)] = long(-456)
+ msg.map_int64_int64[long(-2**33)] = long(-2**34)
+ msg.map_uint32_uint32[long(123)] = long(456)
+ msg.map_uint64_uint64[long(2**33)] = long(2**34)
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(-456, msg2.map_int32_int32[-123])
+ self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
+ self.assertEqual(456, msg2.map_uint32_uint32[123])
+ self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
+
+ def testMapAssignmentCausesPresence(self):
+ msg = map_unittest_pb2.TestMapSubmessage()
+ msg.test_map.map_int32_int32[123] = 456
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMapSubmessage()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(msg, msg2)
+
+ # Now test that various mutations of the map properly invalidate the
+ # cached size of the submessage.
+ msg.test_map.map_int32_int32[888] = 999
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ msg.test_map.map_int32_int32.clear()
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ def testMapAssignmentCausesPresenceForSubmessages(self):
+ msg = map_unittest_pb2.TestMapSubmessage()
+ msg.test_map.map_int32_foreign_message[123].c = 5
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMapSubmessage()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(msg, msg2)
+
+ # Now test that various mutations of the map properly invalidate the
+ # cached size of the submessage.
+ msg.test_map.map_int32_foreign_message[888].c = 7
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ msg.test_map.map_int32_foreign_message[888].MergeFrom(
+ msg.test_map.map_int32_foreign_message[123])
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ msg.test_map.map_int32_foreign_message.clear()
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ def testModifyMapWhileIterating(self):
+ msg = map_unittest_pb2.TestMap()
+
+ string_string_iter = iter(msg.map_string_string)
+ int32_foreign_iter = iter(msg.map_int32_foreign_message)
+
+ msg.map_string_string['abc'] = '123'
+ msg.map_int32_foreign_message[5].c = 5
+
+ with self.assertRaises(RuntimeError):
+ for key in string_string_iter:
+ pass
+
+ with self.assertRaises(RuntimeError):
+ for key in int32_foreign_iter:
+ pass
+
+ def testSubmessageMap(self):
+ msg = map_unittest_pb2.TestMap()
+
+ submsg = msg.map_int32_foreign_message[111]
+ self.assertIs(submsg, msg.map_int32_foreign_message[111])
+ self.assertTrue(isinstance(submsg, unittest_pb2.ForeignMessage))
+
+ submsg.c = 5
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
+
+ # Doesn't allow direct submessage assignment.
+ with self.assertRaises(ValueError):
+ msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage()
+
+ def testMapIteration(self):
+ msg = map_unittest_pb2.TestMap()
+
+ for k, v in msg.map_int32_int32.iteritems():
+ # Should not be reached.
+ self.assertTrue(False)
+
+ msg.map_int32_int32[2] = 4
+ msg.map_int32_int32[3] = 6
+ msg.map_int32_int32[4] = 8
+ self.assertEqual(3, len(msg.map_int32_int32))
+
+ matching_dict = {2: 4, 3: 6, 4: 8}
+ self.assertMapIterEquals(msg.map_int32_int32.iteritems(), matching_dict)
+
+ def testMapIterationClearMessage(self):
+ # Iterator needs to work even if message and map are deleted.
+ msg = map_unittest_pb2.TestMap()
+
+ msg.map_int32_int32[2] = 4
+ msg.map_int32_int32[3] = 6
+ msg.map_int32_int32[4] = 8
+
+ it = msg.map_int32_int32.iteritems()
+ del msg
+
+ matching_dict = {2: 4, 3: 6, 4: 8}
+ self.assertMapIterEquals(it, matching_dict)
+
+ def testMapConstruction(self):
+ msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4})
+ self.assertEqual(2, msg.map_int32_int32[1])
+ self.assertEqual(4, msg.map_int32_int32[3])
+
+ msg = map_unittest_pb2.TestMap(
+ map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)})
+ self.assertEqual(5, msg.map_int32_foreign_message[3].c)
+
+ def testMapValidAfterFieldCleared(self):
+ # Map needs to work even if field is cleared.
+ # For the C++ implementation this tests the correctness of
+ # ScalarMapContainer::Release()
+ msg = map_unittest_pb2.TestMap()
+ map = msg.map_int32_int32
+
+ map[2] = 4
+ map[3] = 6
+ map[4] = 8
+
+ msg.ClearField('map_int32_int32')
+ matching_dict = {2: 4, 3: 6, 4: 8}
+ self.assertMapIterEquals(map.iteritems(), matching_dict)
+
+ def testMapIterValidAfterFieldCleared(self):
+ # Map iterator needs to work even if field is cleared.
+ # For the C++ implementation this tests the correctness of
+ # ScalarMapContainer::Release()
+ msg = map_unittest_pb2.TestMap()
+
+ msg.map_int32_int32[2] = 4
+ msg.map_int32_int32[3] = 6
+ msg.map_int32_int32[4] = 8
+
+ it = msg.map_int32_int32.iteritems()
+
+ msg.ClearField('map_int32_int32')
+ matching_dict = {2: 4, 3: 6, 4: 8}
+ self.assertMapIterEquals(it, matching_dict)
+
+ def testMapDelete(self):
+ msg = map_unittest_pb2.TestMap()
+
+ self.assertEqual(0, len(msg.map_int32_int32))
+
+ msg.map_int32_int32[4] = 6
+ self.assertEqual(1, len(msg.map_int32_int32))
+
+ with self.assertRaises(KeyError):
+ del msg.map_int32_int32[88]
+
+ del msg.map_int32_int32[4]
+ self.assertEqual(0, len(msg.map_int32_int32))
+
+
class ValidTypeNamesTest(unittest.TestCase):
diff --git a/python/google/protobuf/internal/proto_builder_test.py b/python/google/protobuf/internal/proto_builder_test.py
index 9229205a..edaf3fa3 100644
--- a/python/google/protobuf/internal/proto_builder_test.py
+++ b/python/google/protobuf/internal/proto_builder_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -32,6 +32,7 @@
"""Tests for google.protobuf.proto_builder."""
+import collections
import unittest
from google.protobuf import descriptor_pb2
@@ -43,10 +44,11 @@ from google.protobuf import text_format
class ProtoBuilderTest(unittest.TestCase):
def setUp(self):
- self._fields = {
- 'foo': descriptor_pb2.FieldDescriptorProto.TYPE_INT64,
- 'bar': descriptor_pb2.FieldDescriptorProto.TYPE_STRING,
- }
+ self.ordered_fields = collections.OrderedDict([
+ ('foo', descriptor_pb2.FieldDescriptorProto.TYPE_INT64),
+ ('bar', descriptor_pb2.FieldDescriptorProto.TYPE_STRING),
+ ])
+ self._fields = dict(self.ordered_fields)
def testMakeSimpleProtoClass(self):
"""Test that we can create a proto class."""
@@ -59,6 +61,17 @@ class ProtoBuilderTest(unittest.TestCase):
self.assertMultiLineEqual(
'bar: "asdf"\nfoo: 12345\n', text_format.MessageToString(proto))
+ def testOrderedFields(self):
+ """Test that the field order is maintained when given an OrderedDict."""
+ proto_cls = proto_builder.MakeSimpleProtoClass(
+ self.ordered_fields,
+ full_name='net.proto2.python.public.proto_builder_test.OrderedTest')
+ proto = proto_cls()
+ proto.foo = 12345
+ proto.bar = 'asdf'
+ self.assertMultiLineEqual(
+ 'foo: 12345\nbar: "asdf"\n', text_format.MessageToString(proto))
+
def testMakeSameProtoClassTwice(self):
"""Test that the DescriptorPool is used."""
pool = descriptor_pool.DescriptorPool()
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 58c65db9..bb06beb3 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -59,6 +59,7 @@ import weakref
import six
import six.moves.copyreg as copyreg
+import six.string_types
# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import containers
@@ -70,6 +71,7 @@ from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
+from google.protobuf import symbol_database
from google.protobuf import text_format
_FieldDescriptor = descriptor_mod.FieldDescriptor
@@ -94,6 +96,7 @@ def InitMessage(descriptor, cls):
for field in descriptor.fields:
_AttachFieldHelpers(cls, field)
+ descriptor._concrete_class = cls # pylint: disable=protected-access
_AddEnumValues(descriptor, cls)
_AddInitMethod(descriptor, cls)
_AddPropertiesForFields(descriptor, cls)
@@ -191,12 +194,37 @@ def _IsMessageSetExtension(field):
field.label == _FieldDescriptor.LABEL_OPTIONAL)
+def _IsMapField(field):
+ return (field.type == _FieldDescriptor.TYPE_MESSAGE and
+ field.message_type.has_options and
+ field.message_type.GetOptions().map_entry)
+
+
+def _IsMessageMapField(field):
+ value_type = field.message_type.fields_by_name["value"]
+ return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
+
+
def _AttachFieldHelpers(cls, field_descriptor):
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
- is_packed = (field_descriptor.has_options and
- field_descriptor.GetOptions().packed)
-
- if _IsMessageSetExtension(field_descriptor):
+ is_packable = (is_repeated and
+ wire_format.IsTypePackable(field_descriptor.type))
+ if not is_packable:
+ is_packed = False
+ elif field_descriptor.containing_type.syntax == "proto2":
+ is_packed = (field_descriptor.has_options and
+ field_descriptor.GetOptions().packed)
+ else:
+ has_packed_false = (field_descriptor.has_options and
+ field_descriptor.GetOptions().HasField("packed") and
+ field_descriptor.GetOptions().packed == False)
+ is_packed = not has_packed_false
+ is_map_entry = _IsMapField(field_descriptor)
+
+ if is_map_entry:
+ field_encoder = encoder.MapEncoder(field_descriptor)
+ sizer = encoder.MapSizer(field_descriptor)
+ elif _IsMessageSetExtension(field_descriptor):
field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
sizer = encoder.MessageSetItemSizer(field_descriptor.number)
else:
@@ -212,12 +240,27 @@ def _AttachFieldHelpers(cls, field_descriptor):
def AddDecoder(wiretype, is_packed):
tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
- cls._decoders_by_tag[tag_bytes] = (
- type_checkers.TYPE_TO_DECODER[field_descriptor.type](
- field_descriptor.number, is_repeated, is_packed,
- field_descriptor, field_descriptor._default_constructor),
- field_descriptor if field_descriptor.containing_oneof is not None
- else None)
+ decode_type = field_descriptor.type
+ if (decode_type == _FieldDescriptor.TYPE_ENUM and
+ type_checkers.SupportsOpenEnums(field_descriptor)):
+ decode_type = _FieldDescriptor.TYPE_INT32
+
+ oneof_descriptor = None
+ if field_descriptor.containing_oneof is not None:
+ oneof_descriptor = field_descriptor
+
+ if is_map_entry:
+ is_message_map = _IsMessageMapField(field_descriptor)
+
+ field_decoder = decoder.MapDecoder(
+ field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
+ is_message_map)
+ else:
+ field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor)
+
+ cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
False)
@@ -250,6 +293,26 @@ def _AddEnumValues(descriptor, cls):
setattr(cls, enum_value.name, enum_value.number)
+def _GetInitializeDefaultForMap(field):
+ if field.label != _FieldDescriptor.LABEL_REPEATED:
+ raise ValueError('map_entry set on non-repeated field %s' % (
+ field.name))
+ fields_by_name = field.message_type.fields_by_name
+ key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])
+
+ value_field = fields_by_name['value']
+ if _IsMessageMapField(field):
+ def MakeMessageMapDefault(message):
+ return containers.MessageMap(
+ message._listener_for_children, value_field.message_type, key_checker)
+ return MakeMessageMapDefault
+ else:
+ value_checker = type_checkers.GetTypeChecker(value_field)
+ def MakePrimitiveMapDefault(message):
+ return containers.ScalarMap(
+ message._listener_for_children, key_checker, value_checker)
+ return MakePrimitiveMapDefault
+
def _DefaultValueConstructorForField(field):
"""Returns a function which returns a default value for a field.
@@ -264,6 +327,9 @@ def _DefaultValueConstructorForField(field):
value may refer back to |message| via a weak reference.
"""
+ if _IsMapField(field):
+ return _GetInitializeDefaultForMap(field)
+
if field.label == _FieldDescriptor.LABEL_REPEATED:
if field.has_default_value and field.default_value != []:
raise ValueError('Repeated field default value not empty list: %s' % (
@@ -289,6 +355,8 @@ def _DefaultValueConstructorForField(field):
def MakeSubMessageDefault(message):
result = message_type._concrete_class()
result._SetListener(message._listener_for_children)
+ if field.containing_oneof:
+ message._UpdateOneofState(field)
return result
return MakeSubMessageDefault
@@ -312,7 +380,22 @@ def _ReraiseTypeErrorWithFieldName(message_name, field_name):
def _AddInitMethod(message_descriptor, cls):
"""Adds an __init__ method to cls."""
- fields = message_descriptor.fields
+
+ def _GetIntegerEnumValue(enum_type, value):
+ """Convert a string or integer enum value to an integer.
+
+ If the value is a string, it is converted to the enum value in
+ enum_type with the same name. If the value is not a string, it's
+ returned as-is. (No conversion or bounds-checking is done.)
+ """
+ if isinstance(value, six.string_types):
+ try:
+ return enum_type.values_by_name[value].number
+ except KeyError:
+ raise ValueError('Enum type %s: unknown label "%s"' % (
+ enum_type.full_name, value))
+ return value
+
def init(self, **kwargs):
self._cached_byte_size = 0
self._cached_byte_size_dirty = len(kwargs) > 0
@@ -335,19 +418,37 @@ def _AddInitMethod(message_descriptor, cls):
if field.label == _FieldDescriptor.LABEL_REPEATED:
copy = field._default_constructor(self)
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
- for val in field_value:
- copy.add().MergeFrom(val)
+ if _IsMapField(field):
+ if _IsMessageMapField(field):
+ for key in field_value:
+ copy[key].MergeFrom(field_value[key])
+ else:
+ copy.update(field_value)
+ else:
+ for val in field_value:
+ if isinstance(val, dict):
+ copy.add(**val)
+ else:
+ copy.add().MergeFrom(val)
else: # Scalar
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
+ field_value = [_GetIntegerEnumValue(field.enum_type, val)
+ for val in field_value]
copy.extend(field_value)
self._fields[field] = copy
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
copy = field._default_constructor(self)
+ new_val = field_value
+ if isinstance(field_value, dict):
+ new_val = field.message_type._concrete_class(**field_value)
try:
- copy.MergeFrom(field_value)
+ copy.MergeFrom(new_val)
except TypeError:
_ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
self._fields[field] = copy
else:
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
+ field_value = _GetIntegerEnumValue(field.enum_type, field_value)
try:
setattr(self, field_name, field_value)
except TypeError:
@@ -469,6 +570,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
type_checker = type_checkers.GetTypeChecker(field)
default_value = field.default_value
valid_values = set()
+ is_proto3 = field.containing_type.syntax == "proto3"
def getter(self):
# TODO(protobuf-team): This may be broken since there may not be
@@ -476,15 +578,24 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
return self._fields.get(field, default_value)
getter.__module__ = None
getter.__doc__ = 'Getter for %s.' % proto_field_name
+
+ clear_when_set_to_default = is_proto3 and not field.containing_oneof
+
def field_setter(self, new_value):
# pylint: disable=protected-access
- self._fields[field] = type_checker.CheckValue(new_value)
+ # Testing the value for truthiness captures all of the proto3 defaults
+ # (0, 0.0, enum 0, and False).
+ new_value = type_checker.CheckValue(new_value)
+ if clear_when_set_to_default and not new_value:
+ self._fields.pop(field, None)
+ else:
+ self._fields[field] = new_value
# Check _cached_byte_size_dirty inline to improve performance, since scalar
# setters are called frequently.
if not self._cached_byte_size_dirty:
self._Modified()
- if field.containing_oneof is not None:
+ if field.containing_oneof:
def setter(self, new_value):
field_setter(self, new_value)
self._UpdateOneofState(field)
@@ -617,24 +728,35 @@ def _AddListFieldsMethod(message_descriptor, cls):
cls.ListFields = ListFields
+_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
+_Proto2HasError = 'Protocol message has no non-repeated field "%s"'
def _AddHasFieldMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- singular_fields = {}
+ is_proto3 = (message_descriptor.syntax == "proto3")
+ error_msg = _Proto3HasError if is_proto3 else _Proto2HasError
+
+ hassable_fields = {}
for field in message_descriptor.fields:
- if field.label != _FieldDescriptor.LABEL_REPEATED:
- singular_fields[field.name] = field
- # Fields inside oneofs are never repeated (enforced by the compiler).
- for field in message_descriptor.oneofs:
- singular_fields[field.name] = field
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ continue
+ # For proto3, only submessages and fields inside a oneof have presence.
+ if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
+ not field.containing_oneof):
+ continue
+ hassable_fields[field.name] = field
+
+ if not is_proto3:
+ # Fields inside oneofs are never repeated (enforced by the compiler).
+ for oneof in message_descriptor.oneofs:
+ hassable_fields[oneof.name] = oneof
def HasField(self, field_name):
try:
- field = singular_fields[field_name]
+ field = hassable_fields[field_name]
except KeyError:
- raise ValueError(
- 'Protocol message has no singular "%s" field.' % field_name)
+ raise ValueError(error_msg % field_name)
if isinstance(field, descriptor_mod.OneofDescriptor):
try:
@@ -720,6 +842,26 @@ def _AddHasExtensionMethod(cls):
return extension_handle in self._fields
cls.HasExtension = HasExtension
+def _UnpackAny(msg):
+ type_url = msg.type_url
+ db = symbol_database.Default()
+
+ if not type_url:
+ return None
+
+ # TODO(haberman): For now we just strip the hostname. Better logic will be
+ # required.
+ type_name = type_url.split("/")[-1]
+ descriptor = db.pool.FindMessageTypeByName(type_name)
+
+ if descriptor is None:
+ return None
+
+ message_class = db.GetPrototype(descriptor)
+ message = message_class()
+
+ message.ParseFromString(msg.value)
+ return message
def _AddEqualsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
@@ -731,6 +873,12 @@ def _AddEqualsMethod(message_descriptor, cls):
if self is other:
return True
+ if self.DESCRIPTOR.full_name == "google.protobuf.Any":
+ any_a = _UnpackAny(self)
+ any_b = _UnpackAny(other)
+ if any_a and any_b:
+ return any_a == any_b
+
if not self.ListFields() == other.ListFields():
return False
@@ -864,6 +1012,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField
decoders_by_tag = cls._decoders_by_tag
+ is_proto3 = message_descriptor.syntax == "proto3"
def InternalParse(self, buffer, pos, end):
self._Modified()
@@ -877,9 +1026,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
if new_pos == -1:
return pos
- if not unknown_field_list:
- unknown_field_list = self._unknown_fields = []
- unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
+ if not is_proto3:
+ if not unknown_field_list:
+ unknown_field_list = self._unknown_fields = []
+ unknown_field_list.append(
+ (tag_bytes, buffer[value_start_pos:new_pos]))
pos = new_pos
else:
pos = field_decoder(buffer, new_pos, end, self, field_dict)
@@ -920,6 +1071,9 @@ def _AddIsInitializedMethod(message_descriptor, cls):
for field, value in list(self._fields.items()): # dict can change size!
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
if field.label == _FieldDescriptor.LABEL_REPEATED:
+ if (field.message_type.has_options and
+ field.message_type.GetOptions().map_entry):
+ continue
for element in value:
if not element.IsInitialized():
if errors is not None:
@@ -955,16 +1109,26 @@ def _AddIsInitializedMethod(message_descriptor, cls):
else:
name = field.name
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- for i in range(len(value)):
+ if _IsMapField(field):
+ if _IsMessageMapField(field):
+ for key in value:
+ element = value[key]
+ prefix = "%s[%d]." % (name, key)
+ sub_errors = element.FindInitializationErrors()
+ errors += [prefix + error for error in sub_errors]
+ else:
+ # ScalarMaps can't have any initialization errors.
+ pass
+ elif field.label == _FieldDescriptor.LABEL_REPEATED:
+ for i in xrange(len(value)):
element = value[i]
prefix = "%s[%d]." % (name, i)
sub_errors = element.FindInitializationErrors()
- errors += [ prefix + error for error in sub_errors ]
+ errors += [prefix + error for error in sub_errors]
else:
prefix = name + "."
sub_errors = value.FindInitializationErrors()
- errors += [ prefix + error for error in sub_errors ]
+ errors += [prefix + error for error in sub_errors]
return errors
@@ -1001,6 +1165,8 @@ def _AddMergeFromMethod(cls):
# Construct a new object to represent this field.
field_value = field._default_constructor(self)
fields[field] = field_value
+ if field.containing_oneof:
+ self._UpdateOneofState(field)
field_value.MergeFrom(value)
else:
self._fields[field] = value
@@ -1245,11 +1411,10 @@ class _ExtensionDict(object):
# It's slightly wasteful to lookup the type checker each time,
# but we expect this to be a vanishingly uncommon case anyway.
- type_checker = type_checkers.GetTypeChecker(
- extension_handle)
+ type_checker = type_checkers.GetTypeChecker(extension_handle)
# pylint: disable=protected-access
self._extended_message._fields[extension_handle] = (
- type_checker.CheckValue(value))
+ type_checker.CheckValue(value))
self._extended_message._Modified()
def _FindExtensionByName(self, name):
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index a3757992..794395c5 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Protocol Buffers - Google's data interchange format
@@ -1634,7 +1634,7 @@ class ReflectionTest(unittest.TestCase):
self.assertFalse(proto.IsInitialized(errors))
self.assertEqual(errors, ['a', 'b', 'c'])
- @skipIf(
+ @basetest.unittest.skipIf(
api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
'Errors are only available from the most recent C++ implementation.')
def testFileDescriptorErrors(self):
@@ -1665,6 +1665,7 @@ class ReflectionTest(unittest.TestCase):
else:
self.fail("Did not raise TypeError")
+ self.assertTrue('test_file_descriptor_errors.msg1' in message)
self.assertTrue('test_file_descriptor_errors.proto' in message)
def testStringUTF8Encoding(self):
diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py
index e3f71545..9967255a 100755
--- a/python/google/protobuf/internal/service_reflection_test.py
+++ b/python/google/protobuf/internal/service_reflection_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -35,7 +35,6 @@
__author__ = 'petar@google.com (Petar Petrov)'
import unittest
-
from google.protobuf import unittest_pb2
from google.protobuf import service_reflection
from google.protobuf import service
diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py
index bbe602b3..b2489cdb 100644
--- a/python/google/protobuf/internal/symbol_database_test.py
+++ b/python/google/protobuf/internal/symbol_database_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -33,7 +33,6 @@
"""Tests for google.protobuf.symbol_database."""
import unittest
-
from google.protobuf import unittest_pb2
from google.protobuf import symbol_database
diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py
index 787f4650..fec65382 100755
--- a/python/google/protobuf/internal/test_util.py
+++ b/python/google/protobuf/internal/test_util.py
@@ -40,13 +40,19 @@ import os.path
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
+from google.protobuf import descriptor_pb2
+# Tests whether the given TestAllTypes message is proto2 or not.
+# This is used to gate several fields/features that only exist
+# for the proto2 version of the message.
+def IsProto2(message):
+ return message.DESCRIPTOR.syntax == "proto2"
def SetAllNonLazyFields(message):
"""Sets every non-lazy field in the message to a unique value.
Args:
- message: A unittest_pb2.TestAllTypes instance.
+ message: A TestAllTypes instance.
"""
#
@@ -69,7 +75,8 @@ def SetAllNonLazyFields(message):
message.optional_string = u'115'
message.optional_bytes = b'116'
- message.optionalgroup.a = 117
+ if IsProto2(message):
+ message.optionalgroup.a = 117
message.optional_nested_message.bb = 118
message.optional_foreign_message.c = 119
message.optional_import_message.d = 120
@@ -77,7 +84,8 @@ def SetAllNonLazyFields(message):
message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ
message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ
- message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ
+ if IsProto2(message):
+ message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ
message.optional_string_piece = u'124'
message.optional_cord = u'125'
@@ -102,7 +110,8 @@ def SetAllNonLazyFields(message):
message.repeated_string.append(u'215')
message.repeated_bytes.append(b'216')
- message.repeatedgroup.add().a = 217
+ if IsProto2(message):
+ message.repeatedgroup.add().a = 217
message.repeated_nested_message.add().bb = 218
message.repeated_foreign_message.add().c = 219
message.repeated_import_message.add().d = 220
@@ -110,7 +119,8 @@ def SetAllNonLazyFields(message):
message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR)
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR)
- message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR)
+ if IsProto2(message):
+ message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR)
message.repeated_string_piece.append(u'224')
message.repeated_cord.append(u'225')
@@ -132,7 +142,8 @@ def SetAllNonLazyFields(message):
message.repeated_string.append(u'315')
message.repeated_bytes.append(b'316')
- message.repeatedgroup.add().a = 317
+ if IsProto2(message):
+ message.repeatedgroup.add().a = 317
message.repeated_nested_message.add().bb = 318
message.repeated_foreign_message.add().c = 319
message.repeated_import_message.add().d = 320
@@ -140,7 +151,8 @@ def SetAllNonLazyFields(message):
message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ)
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ)
- message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ)
+ if IsProto2(message):
+ message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ)
message.repeated_string_piece.append(u'324')
message.repeated_cord.append(u'325')
@@ -149,28 +161,29 @@ def SetAllNonLazyFields(message):
# Fields that have defaults.
#
- message.default_int32 = 401
- message.default_int64 = 402
- message.default_uint32 = 403
- message.default_uint64 = 404
- message.default_sint32 = 405
- message.default_sint64 = 406
- message.default_fixed32 = 407
- message.default_fixed64 = 408
- message.default_sfixed32 = 409
- message.default_sfixed64 = 410
- message.default_float = 411
- message.default_double = 412
- message.default_bool = False
- message.default_string = '415'
- message.default_bytes = b'416'
-
- message.default_nested_enum = unittest_pb2.TestAllTypes.FOO
- message.default_foreign_enum = unittest_pb2.FOREIGN_FOO
- message.default_import_enum = unittest_import_pb2.IMPORT_FOO
-
- message.default_string_piece = '424'
- message.default_cord = '425'
+ if IsProto2(message):
+ message.default_int32 = 401
+ message.default_int64 = 402
+ message.default_uint32 = 403
+ message.default_uint64 = 404
+ message.default_sint32 = 405
+ message.default_sint64 = 406
+ message.default_fixed32 = 407
+ message.default_fixed64 = 408
+ message.default_sfixed32 = 409
+ message.default_sfixed64 = 410
+ message.default_float = 411
+ message.default_double = 412
+ message.default_bool = False
+ message.default_string = '415'
+ message.default_bytes = b'416'
+
+ message.default_nested_enum = unittest_pb2.TestAllTypes.FOO
+ message.default_foreign_enum = unittest_pb2.FOREIGN_FOO
+ message.default_import_enum = unittest_import_pb2.IMPORT_FOO
+
+ message.default_string_piece = '424'
+ message.default_cord = '425'
message.oneof_uint32 = 601
message.oneof_nested_message.bb = 602
@@ -386,7 +399,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertTrue(message.HasField('optional_string'))
test_case.assertTrue(message.HasField('optional_bytes'))
- test_case.assertTrue(message.HasField('optionalgroup'))
+ if IsProto2(message):
+ test_case.assertTrue(message.HasField('optionalgroup'))
test_case.assertTrue(message.HasField('optional_nested_message'))
test_case.assertTrue(message.HasField('optional_foreign_message'))
test_case.assertTrue(message.HasField('optional_import_message'))
@@ -398,7 +412,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertTrue(message.HasField('optional_nested_enum'))
test_case.assertTrue(message.HasField('optional_foreign_enum'))
- test_case.assertTrue(message.HasField('optional_import_enum'))
+ if IsProto2(message):
+ test_case.assertTrue(message.HasField('optional_import_enum'))
test_case.assertTrue(message.HasField('optional_string_piece'))
test_case.assertTrue(message.HasField('optional_cord'))
@@ -419,7 +434,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual('115', message.optional_string)
test_case.assertEqual(b'116', message.optional_bytes)
- test_case.assertEqual(117, message.optionalgroup.a)
+ if IsProto2(message):
+ test_case.assertEqual(117, message.optionalgroup.a)
test_case.assertEqual(118, message.optional_nested_message.bb)
test_case.assertEqual(119, message.optional_foreign_message.c)
test_case.assertEqual(120, message.optional_import_message.d)
@@ -430,8 +446,9 @@ def ExpectAllFieldsSet(test_case, message):
message.optional_nested_enum)
test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
message.optional_foreign_enum)
- test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
- message.optional_import_enum)
+ if IsProto2(message):
+ test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
+ message.optional_import_enum)
# -----------------------------------------------------------------
@@ -451,13 +468,15 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual(2, len(message.repeated_string))
test_case.assertEqual(2, len(message.repeated_bytes))
- test_case.assertEqual(2, len(message.repeatedgroup))
+ if IsProto2(message):
+ test_case.assertEqual(2, len(message.repeatedgroup))
test_case.assertEqual(2, len(message.repeated_nested_message))
test_case.assertEqual(2, len(message.repeated_foreign_message))
test_case.assertEqual(2, len(message.repeated_import_message))
test_case.assertEqual(2, len(message.repeated_nested_enum))
test_case.assertEqual(2, len(message.repeated_foreign_enum))
- test_case.assertEqual(2, len(message.repeated_import_enum))
+ if IsProto2(message):
+ test_case.assertEqual(2, len(message.repeated_import_enum))
test_case.assertEqual(2, len(message.repeated_string_piece))
test_case.assertEqual(2, len(message.repeated_cord))
@@ -478,7 +497,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual('215', message.repeated_string[0])
test_case.assertEqual(b'216', message.repeated_bytes[0])
- test_case.assertEqual(217, message.repeatedgroup[0].a)
+ if IsProto2(message):
+ test_case.assertEqual(217, message.repeatedgroup[0].a)
test_case.assertEqual(218, message.repeated_nested_message[0].bb)
test_case.assertEqual(219, message.repeated_foreign_message[0].c)
test_case.assertEqual(220, message.repeated_import_message[0].d)
@@ -488,8 +508,9 @@ def ExpectAllFieldsSet(test_case, message):
message.repeated_nested_enum[0])
test_case.assertEqual(unittest_pb2.FOREIGN_BAR,
message.repeated_foreign_enum[0])
- test_case.assertEqual(unittest_import_pb2.IMPORT_BAR,
- message.repeated_import_enum[0])
+ if IsProto2(message):
+ test_case.assertEqual(unittest_import_pb2.IMPORT_BAR,
+ message.repeated_import_enum[0])
test_case.assertEqual(301, message.repeated_int32[1])
test_case.assertEqual(302, message.repeated_int64[1])
@@ -507,7 +528,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual('315', message.repeated_string[1])
test_case.assertEqual(b'316', message.repeated_bytes[1])
- test_case.assertEqual(317, message.repeatedgroup[1].a)
+ if IsProto2(message):
+ test_case.assertEqual(317, message.repeatedgroup[1].a)
test_case.assertEqual(318, message.repeated_nested_message[1].bb)
test_case.assertEqual(319, message.repeated_foreign_message[1].c)
test_case.assertEqual(320, message.repeated_import_message[1].d)
@@ -517,53 +539,55 @@ def ExpectAllFieldsSet(test_case, message):
message.repeated_nested_enum[1])
test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
message.repeated_foreign_enum[1])
- test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
- message.repeated_import_enum[1])
+ if IsProto2(message):
+ test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
+ message.repeated_import_enum[1])
# -----------------------------------------------------------------
- test_case.assertTrue(message.HasField('default_int32'))
- test_case.assertTrue(message.HasField('default_int64'))
- test_case.assertTrue(message.HasField('default_uint32'))
- test_case.assertTrue(message.HasField('default_uint64'))
- test_case.assertTrue(message.HasField('default_sint32'))
- test_case.assertTrue(message.HasField('default_sint64'))
- test_case.assertTrue(message.HasField('default_fixed32'))
- test_case.assertTrue(message.HasField('default_fixed64'))
- test_case.assertTrue(message.HasField('default_sfixed32'))
- test_case.assertTrue(message.HasField('default_sfixed64'))
- test_case.assertTrue(message.HasField('default_float'))
- test_case.assertTrue(message.HasField('default_double'))
- test_case.assertTrue(message.HasField('default_bool'))
- test_case.assertTrue(message.HasField('default_string'))
- test_case.assertTrue(message.HasField('default_bytes'))
-
- test_case.assertTrue(message.HasField('default_nested_enum'))
- test_case.assertTrue(message.HasField('default_foreign_enum'))
- test_case.assertTrue(message.HasField('default_import_enum'))
-
- test_case.assertEqual(401, message.default_int32)
- test_case.assertEqual(402, message.default_int64)
- test_case.assertEqual(403, message.default_uint32)
- test_case.assertEqual(404, message.default_uint64)
- test_case.assertEqual(405, message.default_sint32)
- test_case.assertEqual(406, message.default_sint64)
- test_case.assertEqual(407, message.default_fixed32)
- test_case.assertEqual(408, message.default_fixed64)
- test_case.assertEqual(409, message.default_sfixed32)
- test_case.assertEqual(410, message.default_sfixed64)
- test_case.assertEqual(411, message.default_float)
- test_case.assertEqual(412, message.default_double)
- test_case.assertEqual(False, message.default_bool)
- test_case.assertEqual('415', message.default_string)
- test_case.assertEqual(b'416', message.default_bytes)
-
- test_case.assertEqual(unittest_pb2.TestAllTypes.FOO,
- message.default_nested_enum)
- test_case.assertEqual(unittest_pb2.FOREIGN_FOO,
- message.default_foreign_enum)
- test_case.assertEqual(unittest_import_pb2.IMPORT_FOO,
- message.default_import_enum)
+ if IsProto2(message):
+ test_case.assertTrue(message.HasField('default_int32'))
+ test_case.assertTrue(message.HasField('default_int64'))
+ test_case.assertTrue(message.HasField('default_uint32'))
+ test_case.assertTrue(message.HasField('default_uint64'))
+ test_case.assertTrue(message.HasField('default_sint32'))
+ test_case.assertTrue(message.HasField('default_sint64'))
+ test_case.assertTrue(message.HasField('default_fixed32'))
+ test_case.assertTrue(message.HasField('default_fixed64'))
+ test_case.assertTrue(message.HasField('default_sfixed32'))
+ test_case.assertTrue(message.HasField('default_sfixed64'))
+ test_case.assertTrue(message.HasField('default_float'))
+ test_case.assertTrue(message.HasField('default_double'))
+ test_case.assertTrue(message.HasField('default_bool'))
+ test_case.assertTrue(message.HasField('default_string'))
+ test_case.assertTrue(message.HasField('default_bytes'))
+
+ test_case.assertTrue(message.HasField('default_nested_enum'))
+ test_case.assertTrue(message.HasField('default_foreign_enum'))
+ test_case.assertTrue(message.HasField('default_import_enum'))
+
+ test_case.assertEqual(401, message.default_int32)
+ test_case.assertEqual(402, message.default_int64)
+ test_case.assertEqual(403, message.default_uint32)
+ test_case.assertEqual(404, message.default_uint64)
+ test_case.assertEqual(405, message.default_sint32)
+ test_case.assertEqual(406, message.default_sint64)
+ test_case.assertEqual(407, message.default_fixed32)
+ test_case.assertEqual(408, message.default_fixed64)
+ test_case.assertEqual(409, message.default_sfixed32)
+ test_case.assertEqual(410, message.default_sfixed64)
+ test_case.assertEqual(411, message.default_float)
+ test_case.assertEqual(412, message.default_double)
+ test_case.assertEqual(False, message.default_bool)
+ test_case.assertEqual('415', message.default_string)
+ test_case.assertEqual(b'416', message.default_bytes)
+
+ test_case.assertEqual(unittest_pb2.TestAllTypes.FOO,
+ message.default_nested_enum)
+ test_case.assertEqual(unittest_pb2.FOREIGN_FOO,
+ message.default_foreign_enum)
+ test_case.assertEqual(unittest_import_pb2.IMPORT_FOO,
+ message.default_import_enum)
def GoldenFile(filename):
@@ -578,6 +602,13 @@ def GoldenFile(filename):
return open(full_path, 'rb')
path = os.path.join(path, '..')
+ # Search internally.
+ path = '.'
+ full_path = os.path.join(path, 'third_party/py/google/protobuf/testdata', filename)
+ if os.path.exists(full_path):
+ # Found it. Load the golden file from the testdata directory.
+ return open(full_path, 'rb')
+
raise RuntimeError(
'Could not find golden files. This test must be run from within the '
'protobuf source package so that it can read test data files from the '
@@ -594,7 +625,7 @@ def SetAllPackedFields(message):
"""Sets every field in the message to a unique value.
Args:
- message: A unittest_pb2.TestPackedTypes instance.
+ message: A TestPackedTypes instance.
"""
message.packed_int32.extend([601, 701])
message.packed_int64.extend([602, 702])
diff --git a/python/google/protobuf/internal/text_encoding_test.py b/python/google/protobuf/internal/text_encoding_test.py
index fbd50bb8..9e7b9ce4 100755
--- a/python/google/protobuf/internal/text_encoding_test.py
+++ b/python/google/protobuf/internal/text_encoding_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -33,7 +33,6 @@
"""Tests for google.protobuf.text_encoding."""
import unittest
-
from google.protobuf import text_encoding
TEST_VALUES = [
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index eda38ae9..55b32249 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -35,16 +35,22 @@
__author__ = 'kenton@google.com (Kenton Varda)'
import re
-import unittest
import six
-from google.protobuf import text_format
-from google.protobuf.internal import test_util
-from google.protobuf import unittest_pb2
+import unittest
+from google.protobuf.internal import _parameterized
+
+from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_mset_pb2
+from google.protobuf import unittest_pb2
+from google.protobuf import unittest_proto3_arena_pb2
+from google.protobuf.internal import api_implementation
+from google.protobuf.internal import test_util
+from google.protobuf import text_format
-class TextFormatTest(unittest.TestCase):
+# Base class with some common functionality.
+class TextFormatBase(unittest.TestCase):
def ReadGolden(self, golden_filename):
with test_util.GoldenFile(golden_filename) as f:
@@ -58,73 +64,24 @@ class TextFormatTest(unittest.TestCase):
def CompareToGoldenText(self, text, golden_text):
self.assertMultiLineEqual(text, golden_text)
- def testPrintAllFields(self):
- message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(text_format.MessageToString(message)),
- 'text_format_unittest_data_oneof_implemented.txt')
-
- def testPrintInIndexOrder(self):
- message = unittest_pb2.TestFieldOrderings()
- message.my_string = '115'
- message.my_int = 101
- message.my_float = 111
- message.optional_nested_message.oo = 0
- message.optional_nested_message.bb = 1
- self.CompareToGoldenText(
- self.RemoveRedundantZeros(text_format.MessageToString(
- message, use_index_order=True)),
- 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n'
- 'optional_nested_message {\n oo: 0\n bb: 1\n}\n')
- self.CompareToGoldenText(
- self.RemoveRedundantZeros(text_format.MessageToString(
- message)),
- 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n'
- 'optional_nested_message {\n bb: 1\n oo: 0\n}\n')
-
- def testPrintAllExtensions(self):
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(text_format.MessageToString(message)),
- 'text_format_unittest_extensions_data.txt')
-
- def testPrintAllFieldsPointy(self):
- message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(
- text_format.MessageToString(message, pointy_brackets=True)),
- 'text_format_unittest_data_pointy_oneof.txt')
+ def RemoveRedundantZeros(self, text):
+ # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove
+ # these zeros in order to match the golden file.
+ text = text.replace('e+0','e+').replace('e+0','e+') \
+ .replace('e-0','e-').replace('e-0','e-')
+ # Floating point fields are printed with .0 suffix even if they are
+ # actualy integer numbers.
+ text = re.compile('\.0$', re.MULTILINE).sub('', text)
+ return text
- def testPrintAllExtensionsPointy(self):
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(text_format.MessageToString(
- message, pointy_brackets=True)),
- 'text_format_unittest_extensions_data_pointy.txt')
- def testPrintMessageSet(self):
- message = unittest_mset_pb2.TestMessageSetContainer()
- ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
- ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
- message.message_set.Extensions[ext1].i = 23
- message.message_set.Extensions[ext2].str = 'foo'
- self.CompareToGoldenText(
- text_format.MessageToString(message),
- 'message_set {\n'
- ' [protobuf_unittest.TestMessageSetExtension1] {\n'
- ' i: 23\n'
- ' }\n'
- ' [protobuf_unittest.TestMessageSetExtension2] {\n'
- ' str: \"foo\"\n'
- ' }\n'
- '}\n')
+@_parameterized.Parameters(
+ (unittest_pb2),
+ (unittest_proto3_arena_pb2))
+class TextFormatTest(TextFormatBase):
- def testPrintExotic(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintExotic(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_int64.append(-9223372036854775808)
message.repeated_uint64.append(18446744073709551615)
message.repeated_double.append(123.456)
@@ -143,61 +100,44 @@ class TextFormatTest(unittest.TestCase):
' "\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n'
'repeated_string: "\\303\\274\\352\\234\\237"\n')
- def testPrintExoticUnicodeSubclass(self):
- class UnicodeSub(six.text_type):
+ def testPrintExoticUnicodeSubclass(self, message_module):
+ class UnicodeSub(unicode):
pass
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f'))
self.CompareToGoldenText(
text_format.MessageToString(message),
'repeated_string: "\\303\\274\\352\\234\\237"\n')
- def testPrintNestedMessageAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintNestedMessageAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
msg = message.repeated_nested_message.add()
msg.bb = 42
self.CompareToGoldenText(
text_format.MessageToString(message, as_one_line=True),
'repeated_nested_message { bb: 42 }')
- def testPrintRepeatedFieldsAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintRepeatedFieldsAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_int32.append(1)
message.repeated_int32.append(1)
message.repeated_int32.append(3)
- message.repeated_string.append("Google")
- message.repeated_string.append("Zurich")
+ message.repeated_string.append('Google')
+ message.repeated_string.append('Zurich')
self.CompareToGoldenText(
text_format.MessageToString(message, as_one_line=True),
'repeated_int32: 1 repeated_int32: 1 repeated_int32: 3 '
'repeated_string: "Google" repeated_string: "Zurich"')
- def testPrintNestedNewLineInStringAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
- message.optional_string = "a\nnew\nline"
+ def testPrintNestedNewLineInStringAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
+ message.optional_string = 'a\nnew\nline'
self.CompareToGoldenText(
text_format.MessageToString(message, as_one_line=True),
'optional_string: "a\\nnew\\nline"')
- def testPrintMessageSetAsOneLine(self):
- message = unittest_mset_pb2.TestMessageSetContainer()
- ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
- ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
- message.message_set.Extensions[ext1].i = 23
- message.message_set.Extensions[ext2].str = 'foo'
- self.CompareToGoldenText(
- text_format.MessageToString(message, as_one_line=True),
- 'message_set {'
- ' [protobuf_unittest.TestMessageSetExtension1] {'
- ' i: 23'
- ' }'
- ' [protobuf_unittest.TestMessageSetExtension2] {'
- ' str: \"foo\"'
- ' }'
- ' }')
-
- def testPrintExoticAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintExoticAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_int64.append(-9223372036854775808)
message.repeated_uint64.append(18446744073709551615)
message.repeated_double.append(123.456)
@@ -217,8 +157,8 @@ class TextFormatTest(unittest.TestCase):
'"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""'
' repeated_string: "\\303\\274\\352\\234\\237"')
- def testRoundTripExoticAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
+ def testRoundTripExoticAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_int64.append(-9223372036854775808)
message.repeated_uint64.append(18446744073709551615)
message.repeated_double.append(123.456)
@@ -230,7 +170,7 @@ class TextFormatTest(unittest.TestCase):
# Test as_utf8 = False.
wire_text = text_format.MessageToString(
message, as_one_line=True, as_utf8=False)
- parsed_message = unittest_pb2.TestAllTypes()
+ parsed_message = message_module.TestAllTypes()
r = text_format.Parse(wire_text, parsed_message)
self.assertIs(r, parsed_message)
self.assertEqual(message, parsed_message)
@@ -238,25 +178,25 @@ class TextFormatTest(unittest.TestCase):
# Test as_utf8 = True.
wire_text = text_format.MessageToString(
message, as_one_line=True, as_utf8=True)
- parsed_message = unittest_pb2.TestAllTypes()
+ parsed_message = message_module.TestAllTypes()
r = text_format.Parse(wire_text, parsed_message)
self.assertIs(r, parsed_message)
self.assertEqual(message, parsed_message,
'\n%s != %s' % (message, parsed_message))
- def testPrintRawUtf8String(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintRawUtf8String(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_string.append(u'\u00fc\ua71f')
text = text_format.MessageToString(message, as_utf8=True)
self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n')
- parsed_message = unittest_pb2.TestAllTypes()
+ parsed_message = message_module.TestAllTypes()
text_format.Parse(text, parsed_message)
self.assertEqual(message, parsed_message,
'\n%s != %s' % (message, parsed_message))
- def testPrintFloatFormat(self):
+ def testPrintFloatFormat(self, message_module):
# Check that float_format argument is passed to sub-message formatting.
- message = unittest_pb2.NestedTestAllTypes()
+ message = message_module.NestedTestAllTypes()
# We use 1.25 as it is a round number in binary. The proto 32-bit float
# will not gain additional imprecise digits as a 64-bit Python float and
# show up in its str. 32-bit 1.2 is noisy when extended to 64-bit:
@@ -286,85 +226,24 @@ class TextFormatTest(unittest.TestCase):
self.RemoveRedundantZeros(text_message),
'payload {{ {} {} {} {} }}'.format(*formatted_fields))
- def testMessageToString(self):
- message = unittest_pb2.ForeignMessage()
+ def testMessageToString(self, message_module):
+ message = message_module.ForeignMessage()
message.c = 123
self.assertEqual('c: 123\n', str(message))
- def RemoveRedundantZeros(self, text):
- # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove
- # these zeros in order to match the golden file.
- text = text.replace('e+0','e+').replace('e+0','e+') \
- .replace('e-0','e-').replace('e-0','e-')
- # Floating point fields are printed with .0 suffix even if they are
- # actualy integer numbers.
- text = re.compile('\.0$', re.MULTILINE).sub('', text)
- return text
-
- def testParseGolden(self):
- golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt'))
- parsed_message = unittest_pb2.TestAllTypes()
- r = text_format.Parse(golden_text, parsed_message)
- self.assertIs(r, parsed_message)
-
- message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.assertEqual(message, parsed_message)
-
- def testParseGoldenExtensions(self):
- golden_text = '\n'.join(self.ReadGolden(
- 'text_format_unittest_extensions_data.txt'))
- parsed_message = unittest_pb2.TestAllExtensions()
- text_format.Parse(golden_text, parsed_message)
-
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- self.assertEqual(message, parsed_message)
-
- def testParseAllFields(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseAllFields(self, message_module):
+ message = message_module.TestAllTypes()
test_util.SetAllFields(message)
ascii_text = text_format.MessageToString(message)
- parsed_message = unittest_pb2.TestAllTypes()
+ parsed_message = message_module.TestAllTypes()
text_format.Parse(ascii_text, parsed_message)
self.assertEqual(message, parsed_message)
- test_util.ExpectAllFieldsSet(self, message)
+ if message_module is unittest_pb2:
+ test_util.ExpectAllFieldsSet(self, message)
- def testParseAllExtensions(self):
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- ascii_text = text_format.MessageToString(message)
-
- parsed_message = unittest_pb2.TestAllExtensions()
- text_format.Parse(ascii_text, parsed_message)
- self.assertEqual(message, parsed_message)
-
- def testParseMessageSet(self):
- message = unittest_pb2.TestAllTypes()
- text = ('repeated_uint64: 1\n'
- 'repeated_uint64: 2\n')
- text_format.Parse(text, message)
- self.assertEqual(1, message.repeated_uint64[0])
- self.assertEqual(2, message.repeated_uint64[1])
-
- message = unittest_mset_pb2.TestMessageSetContainer()
- text = ('message_set {\n'
- ' [protobuf_unittest.TestMessageSetExtension1] {\n'
- ' i: 23\n'
- ' }\n'
- ' [protobuf_unittest.TestMessageSetExtension2] {\n'
- ' str: \"foo\"\n'
- ' }\n'
- '}\n')
- text_format.Parse(text, message)
- ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
- ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
- self.assertEqual(23, message.message_set.Extensions[ext1].i)
- self.assertEqual('foo', message.message_set.Extensions[ext2].str)
-
- def testParseExotic(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseExotic(self, message_module):
+ message = message_module.TestAllTypes()
text = ('repeated_int64: -9223372036854775808\n'
'repeated_uint64: 18446744073709551615\n'
'repeated_double: 123.456\n'
@@ -389,8 +268,8 @@ class TextFormatTest(unittest.TestCase):
self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2])
self.assertEqual(u'\u00fc', message.repeated_string[3])
- def testParseTrailingCommas(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseTrailingCommas(self, message_module):
+ message = message_module.TestAllTypes()
text = ('repeated_int64: 100;\n'
'repeated_int64: 200;\n'
'repeated_int64: 300,\n'
@@ -404,101 +283,62 @@ class TextFormatTest(unittest.TestCase):
self.assertEqual(u'one', message.repeated_string[0])
self.assertEqual(u'two', message.repeated_string[1])
- def testParseEmptyText(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseEmptyText(self, message_module):
+ message = message_module.TestAllTypes()
text = ''
text_format.Parse(text, message)
- self.assertEqual(unittest_pb2.TestAllTypes(), message)
+ self.assertEqual(message_module.TestAllTypes(), message)
- def testParseInvalidUtf8(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseInvalidUtf8(self, message_module):
+ message = message_module.TestAllTypes()
text = 'repeated_string: "\\xc3\\xc3"'
self.assertRaises(text_format.ParseError, text_format.Parse, text, message)
- def testParseSingleWord(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseSingleWord(self, message_module):
+ message = message_module.TestAllTypes()
text = 'foo'
- self.assertRaisesWithLiteralMatch(
+ self.assertRaisesRegexp(
text_format.ParseError,
- ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named '
- '"foo".'),
+ (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"foo".'),
text_format.Parse, text, message)
- def testParseUnknownField(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseUnknownField(self, message_module):
+ message = message_module.TestAllTypes()
text = 'unknown_field: 8\n'
- self.assertRaisesWithLiteralMatch(
- text_format.ParseError,
- ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named '
- '"unknown_field".'),
- text_format.Parse, text, message)
-
- def testParseBadExtension(self):
- message = unittest_pb2.TestAllExtensions()
- text = '[unknown_extension]: 8\n'
- self.assertRaisesWithLiteralMatch(
+ self.assertRaisesRegexp(
text_format.ParseError,
- '1:2 : Extension "unknown_extension" not registered.',
- text_format.Parse, text, message)
- message = unittest_pb2.TestAllTypes()
- self.assertRaisesWithLiteralMatch(
- text_format.ParseError,
- ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
- 'extensions.'),
- text_format.Parse, text, message)
-
- def testParseGroupNotClosed(self):
- message = unittest_pb2.TestAllTypes()
- text = 'RepeatedGroup: <'
- self.assertRaisesWithLiteralMatch(
- text_format.ParseError, '1:16 : Expected ">".',
- text_format.Parse, text, message)
-
- text = 'RepeatedGroup: {'
- self.assertRaisesWithLiteralMatch(
- text_format.ParseError, '1:16 : Expected "}".',
+ (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"unknown_field".'),
text_format.Parse, text, message)
- def testParseEmptyGroup(self):
- message = unittest_pb2.TestAllTypes()
- text = 'OptionalGroup: {}'
- text_format.Parse(text, message)
- self.assertTrue(message.HasField('optionalgroup'))
-
- message.Clear()
-
- message = unittest_pb2.TestAllTypes()
- text = 'OptionalGroup: <>'
- text_format.Parse(text, message)
- self.assertTrue(message.HasField('optionalgroup'))
-
- def testParseBadEnumValue(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseBadEnumValue(self, message_module):
+ message = message_module.TestAllTypes()
text = 'optional_nested_enum: BARR'
- self.assertRaisesWithLiteralMatch(
+ self.assertRaisesRegexp(
text_format.ParseError,
- ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
- 'has no value named BARR.'),
+ (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
+ r'has no value named BARR.'),
text_format.Parse, text, message)
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
text = 'optional_nested_enum: 100'
- self.assertRaisesWithLiteralMatch(
+ self.assertRaisesRegexp(
text_format.ParseError,
- ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
- 'has no value with number 100.'),
+ (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
+ r'has no value with number 100.'),
text_format.Parse, text, message)
- def testParseBadIntValue(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseBadIntValue(self, message_module):
+ message = message_module.TestAllTypes()
text = 'optional_int32: bork'
- self.assertRaisesWithLiteralMatch(
+ self.assertRaisesRegexp(
text_format.ParseError,
('1:17 : Couldn\'t parse integer: bork'),
text_format.Parse, text, message)
- def testParseStringFieldUnescape(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseStringFieldUnescape(self, message_module):
+ message = message_module.TestAllTypes()
text = r'''repeated_string: "\xf\x62"
repeated_string: "\\xf\\x62"
repeated_string: "\\\xf\\\x62"
@@ -517,40 +357,254 @@ class TextFormatTest(unittest.TestCase):
message.repeated_string[4])
self.assertEqual(SLASH + 'x20', message.repeated_string[5])
- def testMergeDuplicateScalars(self):
- message = unittest_pb2.TestAllTypes()
+ def testMergeDuplicateScalars(self, message_module):
+ message = message_module.TestAllTypes()
text = ('optional_int32: 42 '
'optional_int32: 67')
r = text_format.Merge(text, message)
self.assertIs(r, message)
self.assertEqual(67, message.optional_int32)
- def testParseDuplicateScalars(self):
- message = unittest_pb2.TestAllTypes()
- text = ('optional_int32: 42 '
- 'optional_int32: 67')
- self.assertRaisesWithLiteralMatch(
- text_format.ParseError,
- ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
- 'have multiple "optional_int32" fields.'),
- text_format.Parse, text, message)
-
- def testMergeDuplicateNestedMessageScalars(self):
- message = unittest_pb2.TestAllTypes()
+ def testMergeDuplicateNestedMessageScalars(self, message_module):
+ message = message_module.TestAllTypes()
text = ('optional_nested_message { bb: 1 } '
'optional_nested_message { bb: 2 }')
r = text_format.Merge(text, message)
self.assertTrue(r is message)
self.assertEqual(2, message.optional_nested_message.bb)
- def testParseDuplicateNestedMessageScalars(self):
+ def testParseOneof(self, message_module):
+ m = message_module.TestAllTypes()
+ m.oneof_uint32 = 11
+ m2 = message_module.TestAllTypes()
+ text_format.Parse(text_format.MessageToString(m), m2)
+ self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
+
+
+# These are tests that aren't fundamentally specific to proto2, but are at
+# the moment because of differences between the proto2 and proto3 test schemas.
+# Ideally the schemas would be made more similar so these tests could pass.
+class OnlyWorksWithProto2RightNowTests(TextFormatBase):
+
+ def testPrintAllFieldsPointy(self, message_module):
message = unittest_pb2.TestAllTypes()
- text = ('optional_nested_message { bb: 1 } '
- 'optional_nested_message { bb: 2 }')
- self.assertRaisesWithLiteralMatch(
+ test_util.SetAllFields(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(
+ text_format.MessageToString(message, pointy_brackets=True)),
+ 'text_format_unittest_data_pointy_oneof.txt')
+
+ def testParseGolden(self):
+ golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt'))
+ parsed_message = unittest_pb2.TestAllTypes()
+ r = text_format.Parse(golden_text, parsed_message)
+ self.assertIs(r, parsed_message)
+
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.assertEqual(message, parsed_message)
+
+ def testPrintAllFields(self):
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(text_format.MessageToString(message)),
+ 'text_format_unittest_data_oneof_implemented.txt')
+
+ def testPrintAllFieldsPointy(self):
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(
+ text_format.MessageToString(message, pointy_brackets=True)),
+ 'text_format_unittest_data_pointy_oneof.txt')
+
+ def testPrintInIndexOrder(self):
+ message = unittest_pb2.TestFieldOrderings()
+ message.my_string = '115'
+ message.my_int = 101
+ message.my_float = 111
+ message.optional_nested_message.oo = 0
+ message.optional_nested_message.bb = 1
+ self.CompareToGoldenText(
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message, use_index_order=True)),
+ 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n'
+ 'optional_nested_message {\n oo: 0\n bb: 1\n}\n')
+ self.CompareToGoldenText(
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message)),
+ 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n'
+ 'optional_nested_message {\n bb: 1\n oo: 0\n}\n')
+
+ def testMergeLinesGolden(self):
+ opened = self.ReadGolden('text_format_unittest_data.txt')
+ parsed_message = unittest_pb2.TestAllTypes()
+ r = text_format.MergeLines(opened, parsed_message)
+ self.assertIs(r, parsed_message)
+
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.assertEqual(message, parsed_message)
+
+ def testParseLinesGolden(self):
+ opened = self.ReadGolden('text_format_unittest_data.txt')
+ parsed_message = unittest_pb2.TestAllTypes()
+ r = text_format.ParseLines(opened, parsed_message)
+ self.assertIs(r, parsed_message)
+
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.assertEqual(message, parsed_message)
+
+ def testPrintMap(self):
+ message = map_unittest_pb2.TestMap()
+
+ message.map_int32_int32[-123] = -456
+ message.map_int64_int64[-2**33] = -2**34
+ message.map_uint32_uint32[123] = 456
+ message.map_uint64_uint64[2**33] = 2**34
+ message.map_string_string["abc"] = "123"
+ message.map_int32_foreign_message[111].c = 5
+
+ # Maps are serialized to text format using their underlying repeated
+ # representation.
+ self.CompareToGoldenText(
+ text_format.MessageToString(message),
+ 'map_int32_int32 {\n'
+ ' key: -123\n'
+ ' value: -456\n'
+ '}\n'
+ 'map_int64_int64 {\n'
+ ' key: -8589934592\n'
+ ' value: -17179869184\n'
+ '}\n'
+ 'map_uint32_uint32 {\n'
+ ' key: 123\n'
+ ' value: 456\n'
+ '}\n'
+ 'map_uint64_uint64 {\n'
+ ' key: 8589934592\n'
+ ' value: 17179869184\n'
+ '}\n'
+ 'map_string_string {\n'
+ ' key: "abc"\n'
+ ' value: "123"\n'
+ '}\n'
+ 'map_int32_foreign_message {\n'
+ ' key: 111\n'
+ ' value {\n'
+ ' c: 5\n'
+ ' }\n'
+ '}\n')
+
+
+# Tests of proto2-only features (MessageSet, extensions, etc.).
+class Proto2Tests(TextFormatBase):
+
+ def testPrintMessageSet(self):
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ self.CompareToGoldenText(
+ text_format.MessageToString(message),
+ 'message_set {\n'
+ ' [protobuf_unittest.TestMessageSetExtension1] {\n'
+ ' i: 23\n'
+ ' }\n'
+ ' [protobuf_unittest.TestMessageSetExtension2] {\n'
+ ' str: \"foo\"\n'
+ ' }\n'
+ '}\n')
+
+ def testPrintMessageSetAsOneLine(self):
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ self.CompareToGoldenText(
+ text_format.MessageToString(message, as_one_line=True),
+ 'message_set {'
+ ' [protobuf_unittest.TestMessageSetExtension1] {'
+ ' i: 23'
+ ' }'
+ ' [protobuf_unittest.TestMessageSetExtension2] {'
+ ' str: \"foo\"'
+ ' }'
+ ' }')
+
+ def testParseMessageSet(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('repeated_uint64: 1\n'
+ 'repeated_uint64: 2\n')
+ text_format.Parse(text, message)
+ self.assertEqual(1, message.repeated_uint64[0])
+ self.assertEqual(2, message.repeated_uint64[1])
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ text = ('message_set {\n'
+ ' [protobuf_unittest.TestMessageSetExtension1] {\n'
+ ' i: 23\n'
+ ' }\n'
+ ' [protobuf_unittest.TestMessageSetExtension2] {\n'
+ ' str: \"foo\"\n'
+ ' }\n'
+ '}\n')
+ text_format.Parse(text, message)
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ self.assertEqual(23, message.message_set.Extensions[ext1].i)
+ self.assertEqual('foo', message.message_set.Extensions[ext2].str)
+
+ def testPrintAllExtensions(self):
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(text_format.MessageToString(message)),
+ 'text_format_unittest_extensions_data.txt')
+
+ def testPrintAllExtensionsPointy(self):
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message, pointy_brackets=True)),
+ 'text_format_unittest_extensions_data_pointy.txt')
+
+ def testParseGoldenExtensions(self):
+ golden_text = '\n'.join(self.ReadGolden(
+ 'text_format_unittest_extensions_data.txt'))
+ parsed_message = unittest_pb2.TestAllExtensions()
+ text_format.Parse(golden_text, parsed_message)
+
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ self.assertEqual(message, parsed_message)
+
+ def testParseAllExtensions(self):
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ ascii_text = text_format.MessageToString(message)
+
+ parsed_message = unittest_pb2.TestAllExtensions()
+ text_format.Parse(ascii_text, parsed_message)
+ self.assertEqual(message, parsed_message)
+
+ def testParseBadExtension(self):
+ message = unittest_pb2.TestAllExtensions()
+ text = '[unknown_extension]: 8\n'
+ self.assertRaisesRegexp(
text_format.ParseError,
- ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" '
- 'should not have multiple "bb" fields.'),
+ '1:2 : Extension "unknown_extension" not registered.',
+ text_format.Parse, text, message)
+ message = unittest_pb2.TestAllTypes()
+ self.assertRaisesRegexp(
+ text_format.ParseError,
+ ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
+ 'extensions.'),
text_format.Parse, text, message)
def testMergeDuplicateExtensionScalars(self):
@@ -566,39 +620,95 @@ class TextFormatTest(unittest.TestCase):
message = unittest_pb2.TestAllExtensions()
text = ('[protobuf_unittest.optional_int32_extension]: 42 '
'[protobuf_unittest.optional_int32_extension]: 67')
- self.assertRaisesWithLiteralMatch(
+ self.assertRaisesRegexp(
text_format.ParseError,
('1:96 : Message type "protobuf_unittest.TestAllExtensions" '
'should not have multiple '
'"protobuf_unittest.optional_int32_extension" extensions.'),
text_format.Parse, text, message)
- def testParseLinesGolden(self):
- opened = self.ReadGolden('text_format_unittest_data.txt')
- parsed_message = unittest_pb2.TestAllTypes()
- r = text_format.ParseLines(opened, parsed_message)
- self.assertIs(r, parsed_message)
+ def testParseDuplicateNestedMessageScalars(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('optional_nested_message { bb: 1 } '
+ 'optional_nested_message { bb: 2 }')
+ self.assertRaisesRegexp(
+ text_format.ParseError,
+ ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" '
+ 'should not have multiple "bb" fields.'),
+ text_format.Parse, text, message)
+ def testParseDuplicateScalars(self):
message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.assertEqual(message, parsed_message)
+ text = ('optional_int32: 42 '
+ 'optional_int32: 67')
+ self.assertRaisesRegexp(
+ text_format.ParseError,
+ ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
+ 'have multiple "optional_int32" fields.'),
+ text_format.Parse, text, message)
- def testMergeLinesGolden(self):
- opened = self.ReadGolden('text_format_unittest_data.txt')
- parsed_message = unittest_pb2.TestAllTypes()
- r = text_format.MergeLines(opened, parsed_message)
- self.assertIs(r, parsed_message)
+ def testParseGroupNotClosed(self):
+ message = unittest_pb2.TestAllTypes()
+ text = 'RepeatedGroup: <'
+ self.assertRaisesRegexp(
+ text_format.ParseError, '1:16 : Expected ">".',
+ text_format.Parse, text, message)
+ text = 'RepeatedGroup: {'
+ self.assertRaisesRegexp(
+ text_format.ParseError, '1:16 : Expected "}".',
+ text_format.Parse, text, message)
+ def testParseEmptyGroup(self):
message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.assertEqual(message, parsed_message)
+ text = 'OptionalGroup: {}'
+ text_format.Parse(text, message)
+ self.assertTrue(message.HasField('optionalgroup'))
- def testParseOneof(self):
- m = unittest_pb2.TestAllTypes()
- m.oneof_uint32 = 11
- m2 = unittest_pb2.TestAllTypes()
- text_format.Parse(text_format.MessageToString(m), m2)
- self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
+ message.Clear()
+
+ message = unittest_pb2.TestAllTypes()
+ text = 'OptionalGroup: <>'
+ text_format.Parse(text, message)
+ self.assertTrue(message.HasField('optionalgroup'))
+
+ # Maps aren't really proto2-only, but our test schema only has maps for
+ # proto2.
+ def testParseMap(self):
+ text = ('map_int32_int32 {\n'
+ ' key: -123\n'
+ ' value: -456\n'
+ '}\n'
+ 'map_int64_int64 {\n'
+ ' key: -8589934592\n'
+ ' value: -17179869184\n'
+ '}\n'
+ 'map_uint32_uint32 {\n'
+ ' key: 123\n'
+ ' value: 456\n'
+ '}\n'
+ 'map_uint64_uint64 {\n'
+ ' key: 8589934592\n'
+ ' value: 17179869184\n'
+ '}\n'
+ 'map_string_string {\n'
+ ' key: "abc"\n'
+ ' value: "123"\n'
+ '}\n'
+ 'map_int32_foreign_message {\n'
+ ' key: 111\n'
+ ' value {\n'
+ ' c: 5\n'
+ ' }\n'
+ '}\n')
+ message = map_unittest_pb2.TestMap()
+ text_format.Parse(text, message)
+
+ self.assertEqual(-456, message.map_int32_int32[-123])
+ self.assertEqual(-2**34, message.map_int64_int64[-2**33])
+ self.assertEqual(456, message.map_uint32_uint32[123])
+ self.assertEqual(2**34, message.map_uint64_uint64[2**33])
+ self.assertEqual("123", message.map_string_string["abc"])
+ self.assertEqual(5, message.map_int32_foreign_message[111].c)
class TokenizerTest(unittest.TestCase):
diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py
index 8d10fbe0..363018ed 100755
--- a/python/google/protobuf/internal/type_checkers.py
+++ b/python/google/protobuf/internal/type_checkers.py
@@ -56,6 +56,8 @@ from google.protobuf import descriptor
_FieldDescriptor = descriptor.FieldDescriptor
+def SupportsOpenEnums(field_descriptor):
+ return field_descriptor.containing_type.syntax == "proto3"
def GetTypeChecker(field):
"""Returns a type checker for a message field of the specified types.
@@ -71,7 +73,11 @@ def GetTypeChecker(field):
field.type == _FieldDescriptor.TYPE_STRING):
return UnicodeValueChecker()
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
- return EnumValueChecker(field.enum_type)
+ if SupportsOpenEnums(field):
+ # When open enums are supported, any int32 can be assigned.
+ return _VALUE_CHECKERS[_FieldDescriptor.CPPTYPE_INT32]
+ else:
+ return EnumValueChecker(field.enum_type)
return _VALUE_CHECKERS[field.cpp_type]
@@ -120,6 +126,9 @@ class IntValueChecker(object):
proposed_value = self._TYPE(proposed_value)
return proposed_value
+ def DefaultValue(self):
+ return 0
+
class EnumValueChecker(object):
@@ -137,6 +146,9 @@ class EnumValueChecker(object):
raise ValueError('Unknown enum value: %d' % proposed_value)
return proposed_value
+ def DefaultValue(self):
+ return self._enum_type.values[0].number
+
class UnicodeValueChecker(object):
@@ -162,6 +174,9 @@ class UnicodeValueChecker(object):
(proposed_value))
return proposed_value
+ def DefaultValue(self):
+ return u""
+
class Int32ValueChecker(IntValueChecker):
# We're sure to use ints instead of longs here since comparison may be more
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py
index e405f113..5cd23d78 100755
--- a/python/google/protobuf/internal/unknown_fields_test.py
+++ b/python/google/protobuf/internal/unknown_fields_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Protocol Buffers - Google's data interchange format
@@ -50,6 +50,7 @@ except ImportError:
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
+from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
from google.protobuf.internal import missing_enum_values_pb2
@@ -57,10 +58,81 @@ from google.protobuf.internal import test_util
from google.protobuf.internal import type_checkers
+class UnknownFieldsTest(unittest.TestCase):
+
+ def setUp(self):
+ self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ self.all_fields = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(self.all_fields)
+ self.all_fields_data = self.all_fields.SerializeToString()
+ self.empty_message = unittest_pb2.TestEmptyMessage()
+ self.empty_message.ParseFromString(self.all_fields_data)
+
+ def testSerialize(self):
+ data = self.empty_message.SerializeToString()
+
+ # Don't use assertEqual because we don't want to dump raw binary data to
+ # stdout.
+ self.assertTrue(data == self.all_fields_data)
+
+ def testSerializeProto3(self):
+ # Verify that proto3 doesn't preserve unknown fields.
+ message = unittest_proto3_arena_pb2.TestEmptyMessage()
+ message.ParseFromString(self.all_fields_data)
+ self.assertEqual(0, len(message.SerializeToString()))
+
+ def testByteSize(self):
+ self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
+
+ def testListFields(self):
+ # Make sure ListFields doesn't return unknown fields.
+ self.assertEqual(0, len(self.empty_message.ListFields()))
+
+ def testSerializeMessageSetWireFormatUnknownExtension(self):
+ # Create a message using the message set wire format with an unknown
+ # message.
+ raw = unittest_mset_pb2.RawMessageSet()
+
+ # Add an unknown extension.
+ item = raw.item.add()
+ item.type_id = 1545009
+ message1 = unittest_mset_pb2.TestMessageSetExtension1()
+ message1.i = 12345
+ item.message = message1.SerializeToString()
+
+ serialized = raw.SerializeToString()
+
+ # Parse message using the message set wire format.
+ proto = unittest_mset_pb2.TestMessageSet()
+ proto.MergeFromString(serialized)
+
+ # Verify that the unknown extension is serialized unchanged
+ reserialized = proto.SerializeToString()
+ new_raw = unittest_mset_pb2.RawMessageSet()
+ new_raw.MergeFromString(reserialized)
+ self.assertEqual(raw, new_raw)
+
+ # C++ implementation for proto2 does not currently take into account unknown
+ # fields when checking equality.
+ #
+ # TODO(haberman): fix this.
+ @unittest.skipIf(
+ api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
+ 'C++ implementation does not expose unknown fields to Python')
+ def testEquals(self):
+ message = unittest_pb2.TestEmptyMessage()
+ message.ParseFromString(self.all_fields_data)
+ self.assertEqual(self.empty_message, message)
+
+ self.all_fields.ClearField('optional_string')
+ message.ParseFromString(self.all_fields.SerializeToString())
+ self.assertNotEqual(self.empty_message, message)
+
+
@skipIf(
api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
'C++ implementation does not expose unknown fields to Python')
-class UnknownFieldsTest(unittest.TestCase):
+class UnknownFieldsAccessorsTest(unittest.TestCase):
def setUp(self):
self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
@@ -110,13 +182,6 @@ class UnknownFieldsTest(unittest.TestCase):
value = self.GetField('optionalgroup')
self.assertEqual(self.all_fields.optionalgroup, value)
- def testSerialize(self):
- data = self.empty_message.SerializeToString()
-
- # Don't use assertEqual because we don't want to dump raw binary data to
- # stdout.
- self.assertTrue(data == self.all_fields_data)
-
def testCopyFrom(self):
message = unittest_pb2.TestEmptyMessage()
message.CopyFrom(self.empty_message)
@@ -144,51 +209,12 @@ class UnknownFieldsTest(unittest.TestCase):
self.empty_message.Clear()
self.assertEqual(0, len(self.empty_message._unknown_fields))
- def testByteSize(self):
- self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
-
def testUnknownExtensions(self):
message = unittest_pb2.TestEmptyMessageWithExtensions()
message.ParseFromString(self.all_fields_data)
self.assertEqual(self.empty_message._unknown_fields,
message._unknown_fields)
- def testListFields(self):
- # Make sure ListFields doesn't return unknown fields.
- self.assertEqual(0, len(self.empty_message.ListFields()))
-
- def testSerializeMessageSetWireFormatUnknownExtension(self):
- # Create a message using the message set wire format with an unknown
- # message.
- raw = unittest_mset_pb2.RawMessageSet()
-
- # Add an unknown extension.
- item = raw.item.add()
- item.type_id = 1545009
- message1 = unittest_mset_pb2.TestMessageSetExtension1()
- message1.i = 12345
- item.message = message1.SerializeToString()
-
- serialized = raw.SerializeToString()
-
- # Parse message using the message set wire format.
- proto = unittest_mset_pb2.TestMessageSet()
- proto.MergeFromString(serialized)
-
- # Verify that the unknown extension is serialized unchanged
- reserialized = proto.SerializeToString()
- new_raw = unittest_mset_pb2.RawMessageSet()
- new_raw.MergeFromString(reserialized)
- self.assertEqual(raw, new_raw)
-
- def testEquals(self):
- message = unittest_pb2.TestEmptyMessage()
- message.ParseFromString(self.all_fields_data)
- self.assertEqual(self.empty_message, message)
-
- self.all_fields.ClearField('optional_string')
- message.ParseFromString(self.all_fields.SerializeToString())
- self.assertNotEqual(self.empty_message, message)
@skipIf(
diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py
index e40a40cc..78dc1167 100755
--- a/python/google/protobuf/internal/wire_format_test.py
+++ b/python/google/protobuf/internal/wire_format_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -35,7 +35,6 @@
__author__ = 'robinson@google.com (Will Robinson)'
import unittest
-
from google.protobuf import message
from google.protobuf.internal import wire_format