aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/internal
diff options
context:
space:
mode:
authorGravatar Feng Xiao <xfxyjwf@gmail.com>2015-12-11 17:09:20 -0800
committerGravatar Feng Xiao <xfxyjwf@gmail.com>2015-12-11 17:10:28 -0800
commite841bac4fcf47f809e089a70d5f84ac37b3883df (patch)
treed25dc5fc814db182c04c5f276ff1a609c5965a5a /python/google/protobuf/internal
parent99a6a95c751a28a3cc33dd2384959179f83f682c (diff)
Down-integrate from internal code base.
Diffstat (limited to 'python/google/protobuf/internal')
-rw-r--r--python/google/protobuf/internal/any_test.proto42
-rwxr-xr-xpython/google/protobuf/internal/containers.py23
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test.py176
-rwxr-xr-xpython/google/protobuf/internal/descriptor_test.py20
-rw-r--r--python/google/protobuf/internal/json_format_test.py58
-rw-r--r--python/google/protobuf/internal/message_factory_test.py7
-rw-r--r--python/google/protobuf/internal/message_set_extensions.proto8
-rwxr-xr-xpython/google/protobuf/internal/message_test.py81
-rwxr-xr-xpython/google/protobuf/internal/python_message.py45
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py18
-rwxr-xr-xpython/google/protobuf/internal/text_format_test.py170
-rw-r--r--python/google/protobuf/internal/well_known_types.py622
-rw-r--r--python/google/protobuf/internal/well_known_types_test.py509
13 files changed, 1701 insertions, 78 deletions
diff --git a/python/google/protobuf/internal/any_test.proto b/python/google/protobuf/internal/any_test.proto
new file mode 100644
index 00000000..cd641ca0
--- /dev/null
+++ b/python/google/protobuf/internal/any_test.proto
@@ -0,0 +1,42 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: jieluo@google.com (Jie Luo)
+
+syntax = "proto3";
+
+package google.protobuf.internal;
+
+import "google/protobuf/any.proto";
+
+message TestAny {
+ google.protobuf.Any value = 1;
+ int32 int_value = 2;
+}
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index 9c8275eb..97cdd848 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -464,6 +464,9 @@ class ScalarMap(MutableMapping):
return val
def __contains__(self, item):
+ # We check the key's type to match the strong-typing flavor of the API.
+ # Also this makes it easier to match the behavior of the C++ implementation.
+ self._key_checker.CheckValue(item)
return item in self._values
# We need to override this explicitly, because our defaultdict-like behavior
@@ -491,10 +494,20 @@ class ScalarMap(MutableMapping):
def __iter__(self):
return iter(self._values)
+ def __repr__(self):
+ return repr(self._values)
+
def MergeFrom(self, other):
self._values.update(other._values)
self._message_listener.Modified()
+ def InvalidateIterators(self):
+ # It appears that the only way to reliably invalidate iterators to
+ # self._values is to ensure that its size changes.
+ original = self._values
+ self._values = original.copy()
+ original[None] = None
+
# This is defined in the abstract base, but we can do it much more cheaply.
def clear(self):
self._values.clear()
@@ -576,12 +589,22 @@ class MessageMap(MutableMapping):
def __iter__(self):
return iter(self._values)
+ def __repr__(self):
+ return repr(self._values)
+
def MergeFrom(self, other):
for key in other:
self[key].MergeFrom(other[key])
# self._message_listener.Modified() not required here, because
# mutations to submessages already propagate.
+ def InvalidateIterators(self):
+ # It appears that the only way to reliably invalidate iterators to
+ # self._values is to ensure that its size changes.
+ original = self._values
+ self._values = original.copy()
+ original[None] = None
+
# This is defined in the abstract base, but we can do it much more cheaply.
def clear(self):
self._values.clear()
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index da9a78db..f1d6bf99 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -40,6 +40,8 @@ try:
import unittest2 as unittest
except ImportError:
import unittest
+from google.protobuf import unittest_import_pb2
+from google.protobuf import unittest_import_public_pb2
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf.internal import api_implementation
@@ -51,13 +53,17 @@ from google.protobuf.internal import test_util
from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import descriptor_pool
+from google.protobuf import message_factory
from google.protobuf import symbol_database
class DescriptorPoolTest(unittest.TestCase):
+ def CreatePool(self):
+ return descriptor_pool.DescriptorPool()
+
def setUp(self):
- self.pool = descriptor_pool.DescriptorPool()
+ self.pool = self.CreatePool()
self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
factory_test1_pb2.DESCRIPTOR.serialized_pb)
self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
@@ -89,7 +95,7 @@ class DescriptorPoolTest(unittest.TestCase):
'google.protobuf.python.internal.Factory1Message')
self.assertIsInstance(file_desc1, descriptor.FileDescriptor)
self.assertEqual('google/protobuf/internal/factory_test1.proto',
- file_desc1.name)
+ file_desc1.name)
self.assertEqual('google.protobuf.python.internal', file_desc1.package)
self.assertIn('Factory1Message', file_desc1.message_types_by_name)
@@ -97,7 +103,7 @@ class DescriptorPoolTest(unittest.TestCase):
'google.protobuf.python.internal.Factory2Message')
self.assertIsInstance(file_desc2, descriptor.FileDescriptor)
self.assertEqual('google/protobuf/internal/factory_test2.proto',
- file_desc2.name)
+ file_desc2.name)
self.assertEqual('google.protobuf.python.internal', file_desc2.package)
self.assertIn('Factory2Message', file_desc2.message_types_by_name)
@@ -111,7 +117,7 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertIsInstance(msg1, descriptor.Descriptor)
self.assertEqual('Factory1Message', msg1.name)
self.assertEqual('google.protobuf.python.internal.Factory1Message',
- msg1.full_name)
+ msg1.full_name)
self.assertEqual(None, msg1.containing_type)
nested_msg1 = msg1.nested_types[0]
@@ -132,7 +138,7 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertIsInstance(msg2, descriptor.Descriptor)
self.assertEqual('Factory2Message', msg2.name)
self.assertEqual('google.protobuf.python.internal.Factory2Message',
- msg2.full_name)
+ msg2.full_name)
self.assertIsNone(msg2.containing_type)
nested_msg2 = msg2.nested_types[0]
@@ -223,6 +229,37 @@ class DescriptorPoolTest(unittest.TestCase):
with self.assertRaises(KeyError):
self.pool.FindEnumTypeByName('Does not exist')
+ def testFindFieldByName(self):
+ field = self.pool.FindFieldByName(
+ 'google.protobuf.python.internal.Factory1Message.list_value')
+ self.assertEqual(field.name, 'list_value')
+ self.assertEqual(field.label, field.LABEL_REPEATED)
+ with self.assertRaises(KeyError):
+ self.pool.FindFieldByName('Does not exist')
+
+ def testFindExtensionByName(self):
+ # An extension defined in a message.
+ extension = self.pool.FindExtensionByName(
+ 'google.protobuf.python.internal.Factory2Message.one_more_field')
+ self.assertEqual(extension.name, 'one_more_field')
+ # An extension defined at file scope.
+ extension = self.pool.FindExtensionByName(
+ 'google.protobuf.python.internal.another_field')
+ self.assertEqual(extension.name, 'another_field')
+ self.assertEqual(extension.number, 1002)
+ with self.assertRaises(KeyError):
+ self.pool.FindFieldByName('Does not exist')
+
+ def testExtensionsAreNotFields(self):
+ with self.assertRaises(KeyError):
+ self.pool.FindFieldByName('google.protobuf.python.internal.another_field')
+ with self.assertRaises(KeyError):
+ self.pool.FindFieldByName(
+ 'google.protobuf.python.internal.Factory2Message.one_more_field')
+ with self.assertRaises(KeyError):
+ self.pool.FindExtensionByName(
+ 'google.protobuf.python.internal.Factory1Message.list_value')
+
def testUserDefinedDB(self):
db = descriptor_database.DescriptorDatabase()
self.pool = descriptor_pool.DescriptorPool(db)
@@ -231,8 +268,7 @@ class DescriptorPoolTest(unittest.TestCase):
self.testFindMessageTypeByName()
def testAddSerializedFile(self):
- db = descriptor_database.DescriptorDatabase()
- self.pool = descriptor_pool.DescriptorPool(db)
+ self.pool = descriptor_pool.DescriptorPool()
self.pool.AddSerializedFile(self.factory_test1_fd.SerializeToString())
self.pool.AddSerializedFile(self.factory_test2_fd.SerializeToString())
self.testFindMessageTypeByName()
@@ -274,6 +310,56 @@ class DescriptorPoolTest(unittest.TestCase):
'google/protobuf/internal/descriptor_pool_test1.proto')
_CheckDefaultValue(file_descriptor)
+ def testDefaultValueForCustomMessages(self):
+ """Check the value returned by non-existent fields."""
+ def _CheckValueAndType(value, expected_value, expected_type):
+ self.assertEqual(value, expected_value)
+ self.assertIsInstance(value, expected_type)
+
+ def _CheckDefaultValues(msg):
+ try:
+ int64 = long
+ except NameError: # Python3
+ int64 = int
+ try:
+ unicode_type = unicode
+ except NameError: # Python3
+ unicode_type = str
+ _CheckValueAndType(msg.optional_int32, 0, int)
+ _CheckValueAndType(msg.optional_uint64, 0, (int64, int))
+ _CheckValueAndType(msg.optional_float, 0, (float, int))
+ _CheckValueAndType(msg.optional_double, 0, (float, int))
+ _CheckValueAndType(msg.optional_bool, False, bool)
+ _CheckValueAndType(msg.optional_string, u'', unicode_type)
+ _CheckValueAndType(msg.optional_bytes, b'', bytes)
+ _CheckValueAndType(msg.optional_nested_enum, msg.FOO, int)
+ # First for the generated message
+ _CheckDefaultValues(unittest_pb2.TestAllTypes())
+ # Then for a message built with from the DescriptorPool.
+ pool = descriptor_pool.DescriptorPool()
+ pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_import_public_pb2.DESCRIPTOR.serialized_pb))
+ pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_import_pb2.DESCRIPTOR.serialized_pb))
+ pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_pb2.DESCRIPTOR.serialized_pb))
+ message_class = message_factory.MessageFactory(pool).GetPrototype(
+ pool.FindMessageTypeByName(
+ unittest_pb2.TestAllTypes.DESCRIPTOR.full_name))
+ _CheckDefaultValues(message_class())
+
+
+@unittest.skipIf(api_implementation.Type() != 'cpp',
+ 'explicit tests of the C++ implementation')
+class CppDescriptorPoolTest(DescriptorPoolTest):
+ # TODO(amauryfa): remove when descriptor_pool.DescriptorPool() creates true
+ # C++ descriptor pool object for C++ implementation.
+
+ def CreatePool(self):
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.pyext import _message
+ return _message.DescriptorPool()
+
class ProtoFile(object):
@@ -468,6 +554,8 @@ class AddDescriptorTest(unittest.TestCase):
pool.FindFileContainingSymbol(
prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').name)
+ @unittest.skipIf(api_implementation.Type() == 'cpp',
+ 'With the cpp implementation, Add() must be called first')
def testMessage(self):
self._TestMessage('')
self._TestMessage('.')
@@ -502,10 +590,14 @@ class AddDescriptorTest(unittest.TestCase):
pool.FindFileContainingSymbol(
prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').name)
+ @unittest.skipIf(api_implementation.Type() == 'cpp',
+ 'With the cpp implementation, Add() must be called first')
def testEnum(self):
self._TestEnum('')
self._TestEnum('.')
+ @unittest.skipIf(api_implementation.Type() == 'cpp',
+ 'With the cpp implementation, Add() must be called first')
def testFile(self):
pool = descriptor_pool.DescriptorPool()
pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR)
@@ -520,6 +612,76 @@ class AddDescriptorTest(unittest.TestCase):
pool.FindFileContainingSymbol(
'protobuf_unittest.TestAllTypes')
+ def _GetDescriptorPoolClass(self):
+ # Test with both implementations of descriptor pools.
+ if api_implementation.Type() == 'cpp':
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.pyext import _message
+ return _message.DescriptorPool
+ else:
+ return descriptor_pool.DescriptorPool
+
+ def testEmptyDescriptorPool(self):
+ # Check that an empty DescriptorPool() contains no message.
+ pool = self._GetDescriptorPoolClass()()
+ proto_file_name = descriptor_pb2.DESCRIPTOR.name
+ self.assertRaises(KeyError, pool.FindFileByName, proto_file_name)
+ # Add the above file to the pool
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ descriptor_pb2.DESCRIPTOR.CopyToProto(file_descriptor)
+ pool.Add(file_descriptor)
+ # Now it exists.
+ self.assertTrue(pool.FindFileByName(proto_file_name))
+
+ def testCustomDescriptorPool(self):
+ # Create a new pool, and add a file descriptor.
+ pool = self._GetDescriptorPoolClass()()
+ file_desc = descriptor_pb2.FileDescriptorProto(
+ name='some/file.proto', package='package')
+ file_desc.message_type.add(name='Message')
+ pool.Add(file_desc)
+ self.assertEqual(pool.FindFileByName('some/file.proto').name,
+ 'some/file.proto')
+ self.assertEqual(pool.FindMessageTypeByName('package.Message').name,
+ 'Message')
+
+
+@unittest.skipIf(
+ api_implementation.Type() != 'cpp',
+ 'default_pool is only supported by the C++ implementation')
+class DefaultPoolTest(unittest.TestCase):
+
+ def testFindMethods(self):
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.pyext import _message
+ pool = _message.default_pool
+ self.assertIs(
+ pool.FindFileByName('google/protobuf/unittest.proto'),
+ unittest_pb2.DESCRIPTOR)
+ self.assertIs(
+ pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR)
+ self.assertIs(
+ pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32'])
+ self.assertIs(
+ pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'),
+ unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension'])
+ self.assertIs(
+ pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'),
+ unittest_pb2.ForeignEnum.DESCRIPTOR)
+ self.assertIs(
+ pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field'])
+
+ def testAddFileDescriptor(self):
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.pyext import _message
+ pool = _message.default_pool
+ file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto')
+ pool.Add(file_desc)
+ pool.AddSerializedFile(file_desc.SerializeToString())
+
TEST1_FILE = ProtoFile(
'google/protobuf/internal/descriptor_pool_test1.proto',
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index 99afee63..fee09a56 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -47,6 +47,7 @@ from google.protobuf import descriptor_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import test_util
from google.protobuf import descriptor
+from google.protobuf import descriptor_pool
from google.protobuf import symbol_database
from google.protobuf import text_format
@@ -75,9 +76,9 @@ class DescriptorTest(unittest.TestCase):
enum_proto.value.add(name='FOREIGN_BAR', number=5)
enum_proto.value.add(name='FOREIGN_BAZ', number=6)
- descriptor_pool = symbol_database.Default().pool
- descriptor_pool.Add(file_proto)
- self.my_file = descriptor_pool.FindFileByName(file_proto.name)
+ self.pool = self.GetDescriptorPool()
+ self.pool.Add(file_proto)
+ self.my_file = self.pool.FindFileByName(file_proto.name)
self.my_message = self.my_file.message_types_by_name[message_proto.name]
self.my_enum = self.my_message.enum_types_by_name[enum_proto.name]
@@ -97,6 +98,9 @@ class DescriptorTest(unittest.TestCase):
self.my_method
])
+ def GetDescriptorPool(self):
+ return symbol_database.Default().pool
+
def testEnumValueName(self):
self.assertEqual(self.my_message.EnumValueName('ForeignEnum', 4),
'FOREIGN_FOO')
@@ -393,6 +397,9 @@ class DescriptorTest(unittest.TestCase):
def testFileDescriptor(self):
self.assertEqual(self.my_file.name, 'some/filename/some.proto')
self.assertEqual(self.my_file.package, 'protobuf_unittest')
+ self.assertEqual(self.my_file.pool, self.pool)
+ # Generated modules also belong to the default pool.
+ self.assertEqual(unittest_pb2.DESCRIPTOR.pool, descriptor_pool.Default())
@unittest.skipIf(
api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
@@ -407,6 +414,13 @@ class DescriptorTest(unittest.TestCase):
message_descriptor.fields.append(None)
+class NewDescriptorTest(DescriptorTest):
+ """Redo the same tests as above, but with a separate DescriptorPool."""
+
+ def GetDescriptorPool(self):
+ return descriptor_pool.DescriptorPool()
+
+
class GeneratedDescriptorTest(unittest.TestCase):
"""Tests for the properties of descriptors in generated code."""
diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py
index 69197865..be3ad11a 100644
--- a/python/google/protobuf/internal/json_format_test.py
+++ b/python/google/protobuf/internal/json_format_test.py
@@ -42,6 +42,7 @@ try:
import unittest2 as unittest
except ImportError:
import unittest
+from google.protobuf.internal import well_known_types
from google.protobuf import json_format
from google.protobuf.util import json_format_proto3_pb2
@@ -269,15 +270,15 @@ class JsonFormatTest(JsonFormatBase):
'}'))
parsed_message = json_format_proto3_pb2.TestTimestamp()
self.CheckParseBack(message, parsed_message)
- text = (r'{"value": "1972-01-01T01:00:00.01+08:00",'
+ text = (r'{"value": "1970-01-01T00:00:00.01+08:00",'
r'"repeatedValue":['
- r' "1972-01-01T01:00:00.01+08:30",'
- r' "1972-01-01T01:00:00.01-01:23"]}')
+ r' "1970-01-01T00:00:00.01+08:30",'
+ r' "1970-01-01T00:00:00.01-01:23"]}')
json_format.Parse(text, parsed_message)
- self.assertEqual(parsed_message.value.seconds, 63104400)
+ self.assertEqual(parsed_message.value.seconds, -8 * 3600)
self.assertEqual(parsed_message.value.nanos, 10000000)
- self.assertEqual(parsed_message.repeated_value[0].seconds, 63106200)
- self.assertEqual(parsed_message.repeated_value[1].seconds, 63070620)
+ self.assertEqual(parsed_message.repeated_value[0].seconds, -8.5 * 3600)
+ self.assertEqual(parsed_message.repeated_value[1].seconds, 3600 + 23 * 60)
def testDurationMessage(self):
message = json_format_proto3_pb2.TestDuration()
@@ -389,7 +390,7 @@ class JsonFormatTest(JsonFormatBase):
def testParseEmptyText(self):
self.CheckError('',
- r'Failed to load JSON: (Expecting value)|(No JSON)')
+ r'Failed to load JSON: (Expecting value)|(No JSON).')
def testParseBadEnumValue(self):
self.CheckError(
@@ -414,7 +415,7 @@ class JsonFormatTest(JsonFormatBase):
if sys.version_info < (2, 7):
return
self.CheckError('{"int32Value": 1,\n"int32Value":2}',
- 'Failed to load JSON: duplicate key int32Value')
+ 'Failed to load JSON: duplicate key int32Value.')
def testInvalidBoolValue(self):
self.CheckError('{"boolValue": 1}',
@@ -431,39 +432,43 @@ class JsonFormatTest(JsonFormatBase):
json_format.Parse, text, message)
self.CheckError('{"int32Value": 012345}',
(r'Failed to load JSON: Expecting \'?,\'? delimiter: '
- r'line 1'))
+ r'line 1.'))
self.CheckError('{"int32Value": 1.0}',
'Failed to parse int32Value field: '
- 'Couldn\'t parse integer: 1.0')
+ 'Couldn\'t parse integer: 1.0.')
self.CheckError('{"int32Value": " 1 "}',
'Failed to parse int32Value field: '
- 'Couldn\'t parse integer: " 1 "')
+ 'Couldn\'t parse integer: " 1 ".')
+ self.CheckError('{"int32Value": "1 "}',
+ 'Failed to parse int32Value field: '
+ 'Couldn\'t parse integer: "1 ".')
self.CheckError('{"int32Value": 12345678901234567890}',
'Failed to parse int32Value field: Value out of range: '
- '12345678901234567890')
+ '12345678901234567890.')
self.CheckError('{"int32Value": 1e5}',
'Failed to parse int32Value field: '
- 'Couldn\'t parse integer: 100000.0')
+ 'Couldn\'t parse integer: 100000.0.')
self.CheckError('{"uint32Value": -1}',
- 'Failed to parse uint32Value field: Value out of range: -1')
+ 'Failed to parse uint32Value field: '
+ 'Value out of range: -1.')
def testInvalidFloatValue(self):
self.CheckError('{"floatValue": "nan"}',
'Failed to parse floatValue field: Couldn\'t '
- 'parse float "nan", use "NaN" instead')
+ 'parse float "nan", use "NaN" instead.')
def testInvalidBytesValue(self):
self.CheckError('{"bytesValue": "AQI"}',
- 'Failed to parse bytesValue field: Incorrect padding')
+ 'Failed to parse bytesValue field: Incorrect padding.')
self.CheckError('{"bytesValue": "AQI*"}',
- 'Failed to parse bytesValue field: Incorrect padding')
+ 'Failed to parse bytesValue field: Incorrect padding.')
def testInvalidMap(self):
message = json_format_proto3_pb2.TestMap()
text = '{"int32Map": {"null": 2, "2": 3}}'
self.assertRaisesRegexp(
json_format.ParseError,
- 'Failed to parse int32Map field: Couldn\'t parse integer: "null"',
+ 'Failed to parse int32Map field: invalid literal',
json_format.Parse, text, message)
text = '{"int32Map": {1: 2, "2": 3}}'
self.assertRaisesRegexp(
@@ -474,7 +479,7 @@ class JsonFormatTest(JsonFormatBase):
text = '{"boolMap": {"null": 1}}'
self.assertRaisesRegexp(
json_format.ParseError,
- 'Failed to parse boolMap field: Expect "true" or "false", not null.',
+ 'Failed to parse boolMap field: Expected "true" or "false", not null.',
json_format.Parse, text, message)
if sys.version_info < (2, 7):
return
@@ -490,30 +495,29 @@ class JsonFormatTest(JsonFormatBase):
self.assertRaisesRegexp(
json_format.ParseError,
'time data \'10000-01-01T00:00:00\' does not match'
- ' format \'%Y-%m-%dT%H:%M:%S\'',
+ ' format \'%Y-%m-%dT%H:%M:%S\'.',
json_format.Parse, text, message)
text = '{"value": "1970-01-01T00:00:00.0123456789012Z"}'
self.assertRaisesRegexp(
- json_format.ParseError,
- 'Failed to parse value field: Failed to parse Timestamp: '
+ well_known_types.ParseError,
'nanos 0123456789012 more than 9 fractional digits.',
json_format.Parse, text, message)
text = '{"value": "1972-01-01T01:00:00.01+08"}'
self.assertRaisesRegexp(
- json_format.ParseError,
- (r'Failed to parse value field: Invalid timezone offset value: \+08'),
+ well_known_types.ParseError,
+ (r'Invalid timezone offset value: \+08.'),
json_format.Parse, text, message)
# Time smaller than minimum time.
text = '{"value": "0000-01-01T00:00:00Z"}'
self.assertRaisesRegexp(
json_format.ParseError,
- 'Failed to parse value field: year is out of range',
+ 'Failed to parse value field: year is out of range.',
json_format.Parse, text, message)
# Time bigger than maxinum time.
message.value.seconds = 253402300800
self.assertRaisesRegexp(
- json_format.SerializeToJsonError,
- 'Failed to serialize value field: year is out of range',
+ OverflowError,
+ 'date value out of range',
json_format.MessageToJson, message)
def testInvalidOneof(self):
diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py
index d760b898..2fbe5ea7 100644
--- a/python/google/protobuf/internal/message_factory_test.py
+++ b/python/google/protobuf/internal/message_factory_test.py
@@ -45,6 +45,7 @@ from google.protobuf import descriptor_database
from google.protobuf import descriptor_pool
from google.protobuf import message_factory
+
class MessageFactoryTest(unittest.TestCase):
def setUp(self):
@@ -104,8 +105,8 @@ class MessageFactoryTest(unittest.TestCase):
def testGetMessages(self):
# performed twice because multiple calls with the same input must be allowed
for _ in range(2):
- messages = message_factory.GetMessages([self.factory_test2_fd,
- self.factory_test1_fd])
+ messages = message_factory.GetMessages([self.factory_test1_fd,
+ self.factory_test2_fd])
self.assertTrue(
set(['google.protobuf.python.internal.Factory2Message',
'google.protobuf.python.internal.Factory1Message'],
@@ -116,7 +117,7 @@ class MessageFactoryTest(unittest.TestCase):
set(['google.protobuf.python.internal.Factory2Message.one_more_field',
'google.protobuf.python.internal.another_field'],
).issubset(
- set(messages['google.protobuf.python.internal.Factory1Message']
+ set(messages['google.protobuf.python.internal.Factory1Message']
._extensions_by_name.keys())))
factory_msg1 = messages['google.protobuf.python.internal.Factory1Message']
msg1 = messages['google.protobuf.python.internal.Factory1Message']()
diff --git a/python/google/protobuf/internal/message_set_extensions.proto b/python/google/protobuf/internal/message_set_extensions.proto
index 702c8d07..14e5f193 100644
--- a/python/google/protobuf/internal/message_set_extensions.proto
+++ b/python/google/protobuf/internal/message_set_extensions.proto
@@ -54,6 +54,14 @@ message TestMessageSetExtension2 {
optional string str = 25;
}
+message TestMessageSetExtension3 {
+ optional string text = 35;
+}
+
+extend TestMessageSet {
+ optional TestMessageSetExtension3 message_set_extension3 = 98418655;
+}
+
// This message was used to generate
// //net/proto2/python/internal/testdata/message_set_message, but is commented
// out since it must not actually exist in code, to simulate an "unknown"
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 13c3caa6..d03f2d25 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -60,6 +60,7 @@ from google.protobuf.internal import _parameterized
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
+from google.protobuf.internal import any_test_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import packed_field_test_pb2
from google.protobuf.internal import test_util
@@ -1279,12 +1280,13 @@ class Proto3Test(unittest.TestCase):
self.assertIsInstance(msg.map_string_string['abc'], six.text_type)
- # Accessing an unset key still throws TypeError of the type of the key
+ # Accessing an unset key still throws TypeError if the type of the key
# is incorrect.
with self.assertRaises(TypeError):
msg.map_string_string[123]
- self.assertFalse(123 in msg.map_string_string)
+ with self.assertRaises(TypeError):
+ 123 in msg.map_string_string
def testMapGet(self):
# Need to test that get() properly returns the default, even though the dict
@@ -1591,31 +1593,49 @@ class Proto3Test(unittest.TestCase):
# For the C++ implementation this tests the correctness of
# ScalarMapContainer::Release()
msg = map_unittest_pb2.TestMap()
- map = msg.map_int32_int32
+ int32_map = msg.map_int32_int32
- map[2] = 4
- map[3] = 6
- map[4] = 8
+ int32_map[2] = 4
+ int32_map[3] = 6
+ int32_map[4] = 8
msg.ClearField('map_int32_int32')
+ self.assertEqual(b'', msg.SerializeToString())
matching_dict = {2: 4, 3: 6, 4: 8}
- self.assertMapIterEquals(map.items(), matching_dict)
+ self.assertMapIterEquals(int32_map.items(), matching_dict)
- def testMapIterValidAfterFieldCleared(self):
- # Map iterator needs to work even if field is cleared.
+ def testMessageMapValidAfterFieldCleared(self):
+ # Map needs to work even if field is cleared.
# For the C++ implementation this tests the correctness of
# ScalarMapContainer::Release()
msg = map_unittest_pb2.TestMap()
+ int32_foreign_message = msg.map_int32_foreign_message
- msg.map_int32_int32[2] = 4
- msg.map_int32_int32[3] = 6
- msg.map_int32_int32[4] = 8
+ int32_foreign_message[2].c = 5
- it = msg.map_int32_int32.items()
+ msg.ClearField('map_int32_foreign_message')
+ self.assertEqual(b'', msg.SerializeToString())
+ self.assertTrue(2 in int32_foreign_message.keys())
+
+ def testMapIterInvalidatedByClearField(self):
+ # Map iterator is invalidated when field is cleared.
+ # But this case does need to not crash the interpreter.
+ # For the C++ implementation this tests the correctness of
+ # ScalarMapContainer::Release()
+ msg = map_unittest_pb2.TestMap()
+
+ it = iter(msg.map_int32_int32)
msg.ClearField('map_int32_int32')
- matching_dict = {2: 4, 3: 6, 4: 8}
- self.assertMapIterEquals(it, matching_dict)
+ with self.assertRaises(RuntimeError):
+ for _ in it:
+ pass
+
+ it = iter(msg.map_int32_foreign_message)
+ msg.ClearField('map_int32_foreign_message')
+ with self.assertRaises(RuntimeError):
+ for _ in it:
+ pass
def testMapDelete(self):
msg = map_unittest_pb2.TestMap()
@@ -1646,6 +1666,37 @@ class Proto3Test(unittest.TestCase):
msg.map_string_foreign_message['foo'].c = 5
self.assertEqual(0, len(msg.FindInitializationErrors()))
+ def testAnyMessage(self):
+ # Creates and sets message.
+ msg = any_test_pb2.TestAny()
+ msg_descriptor = msg.DESCRIPTOR
+ all_types = unittest_pb2.TestAllTypes()
+ all_descriptor = all_types.DESCRIPTOR
+ all_types.repeated_string.append(u'\u00fc\ua71f')
+ # Packs to Any.
+ msg.value.Pack(all_types)
+ self.assertEqual(msg.value.type_url,
+ 'type.googleapis.com/%s' % all_descriptor.full_name)
+ self.assertEqual(msg.value.value,
+ all_types.SerializeToString())
+ # Tests Is() method.
+ self.assertTrue(msg.value.Is(all_descriptor))
+ self.assertFalse(msg.value.Is(msg_descriptor))
+ # Unpacks Any.
+ unpacked_message = unittest_pb2.TestAllTypes()
+ self.assertTrue(msg.value.Unpack(unpacked_message))
+ self.assertEqual(all_types, unpacked_message)
+ # Unpacks to different type.
+ self.assertFalse(msg.value.Unpack(msg))
+ # Only Any messages have Pack method.
+ try:
+ msg.Pack(all_types)
+ except AttributeError:
+ pass
+ else:
+ raise AttributeError('%s should not have Pack method.' %
+ msg_descriptor.full_name)
+
class ValidTypeNamesTest(unittest.TestCase):
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 2b87f704..87f60666 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -65,6 +65,7 @@ from google.protobuf.internal import encoder
from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import message_listener as message_listener_mod
from google.protobuf.internal import type_checkers
+from google.protobuf.internal import well_known_types
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
@@ -72,6 +73,7 @@ from google.protobuf import symbol_database
from google.protobuf import text_format
_FieldDescriptor = descriptor_mod.FieldDescriptor
+_AnyFullTypeName = 'google.protobuf.Any'
class GeneratedProtocolMessageType(type):
@@ -127,6 +129,8 @@ class GeneratedProtocolMessageType(type):
Newly-allocated class.
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
+ if descriptor.full_name in well_known_types.WKTBASES:
+ bases += (well_known_types.WKTBASES[descriptor.full_name],)
_AddClassAttributesForNestedExtensions(descriptor, dictionary)
_AddSlots(descriptor, dictionary)
@@ -261,7 +265,6 @@ def _IsMessageSetExtension(field):
field.containing_type.has_options and
field.containing_type.GetOptions().message_set_wire_format and
field.type == _FieldDescriptor.TYPE_MESSAGE and
- field.message_type == field.extension_scope and
field.label == _FieldDescriptor.LABEL_OPTIONAL)
@@ -543,7 +546,8 @@ def _GetFieldByName(message_descriptor, field_name):
try:
return message_descriptor.fields_by_name[field_name]
except KeyError:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
+ raise ValueError('Protocol message %s has no "%s" field.' %
+ (message_descriptor.name, field_name))
def _AddPropertiesForFields(descriptor, cls):
@@ -848,9 +852,15 @@ def _AddClearFieldMethod(message_descriptor, cls):
else:
return
except KeyError:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
+ raise ValueError('Protocol message %s() has no "%s" field.' %
+ (message_descriptor.name, field_name))
if field in self._fields:
+ # To match the C++ implementation, we need to invalidate iterators
+ # for map fields when ClearField() happens.
+ if hasattr(self._fields[field], 'InvalidateIterators'):
+ self._fields[field].InvalidateIterators()
+
# Note: If the field is a sub-message, its listener will still point
# at us. That's fine, because the worst than can happen is that it
# will call _Modified() and invalidate our byte size. Big deal.
@@ -904,7 +914,19 @@ def _AddHasExtensionMethod(cls):
return extension_handle in self._fields
cls.HasExtension = HasExtension
-def _UnpackAny(msg):
+def _InternalUnpackAny(msg):
+ """Unpacks Any message and returns the unpacked message.
+
+ This internal method is differnt from public Any Unpack method which takes
+ the target message as argument. _InternalUnpackAny method does not have
+ target message type and need to find the message type in descriptor pool.
+
+ Args:
+ msg: An Any message to be unpacked.
+
+ Returns:
+ The unpacked message.
+ """
type_url = msg.type_url
db = symbol_database.Default()
@@ -935,9 +957,9 @@ def _AddEqualsMethod(message_descriptor, cls):
if self is other:
return True
- if self.DESCRIPTOR.full_name == "google.protobuf.Any":
- any_a = _UnpackAny(self)
- any_b = _UnpackAny(other)
+ if self.DESCRIPTOR.full_name == _AnyFullTypeName:
+ any_a = _InternalUnpackAny(self)
+ any_b = _InternalUnpackAny(other)
if any_a and any_b:
return any_a == any_b
@@ -962,6 +984,13 @@ def _AddStrMethod(message_descriptor, cls):
cls.__str__ = __str__
+def _AddReprMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ def __repr__(self):
+ return text_format.MessageToString(self)
+ cls.__repr__ = __repr__
+
+
def _AddUnicodeMethod(unused_message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
@@ -1270,6 +1299,7 @@ def _AddMessageMethods(message_descriptor, cls):
_AddClearMethod(message_descriptor, cls)
_AddEqualsMethod(message_descriptor, cls)
_AddStrMethod(message_descriptor, cls)
+ _AddReprMethod(message_descriptor, cls)
_AddUnicodeMethod(message_descriptor, cls)
_AddSetListenerMethod(cls)
_AddByteSizeMethod(message_descriptor, cls)
@@ -1280,6 +1310,7 @@ def _AddMessageMethods(message_descriptor, cls):
_AddMergeFromMethod(cls)
_AddWhichOneofMethod(message_descriptor, cls)
+
def _AddPrivateHelperMethods(message_descriptor, cls):
"""Adds implementation of private helper methods to cls."""
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 8621d61e..752f2f5d 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -2394,8 +2394,10 @@ class SerializationTest(unittest.TestCase):
extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2
extension1 = extension_message1.message_set_extension
extension2 = extension_message2.message_set_extension
+ extension3 = message_set_extensions_pb2.message_set_extension3
proto.Extensions[extension1].i = 123
proto.Extensions[extension2].str = 'foo'
+ proto.Extensions[extension3].text = 'bar'
# Serialize using the MessageSet wire format (this is specified in the
# .proto file).
@@ -2407,7 +2409,7 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(
len(serialized),
raw.MergeFromString(serialized))
- self.assertEqual(2, len(raw.item))
+ self.assertEqual(3, len(raw.item))
message1 = message_set_extensions_pb2.TestMessageSetExtension1()
self.assertEqual(
@@ -2421,6 +2423,12 @@ class SerializationTest(unittest.TestCase):
message2.MergeFromString(raw.item[1].message))
self.assertEqual('foo', message2.str)
+ message3 = message_set_extensions_pb2.TestMessageSetExtension3()
+ self.assertEqual(
+ len(raw.item[2].message),
+ message3.MergeFromString(raw.item[2].message))
+ self.assertEqual('bar', message3.text)
+
# Deserialize using the MessageSet wire format.
proto2 = message_set_extensions_pb2.TestMessageSet()
self.assertEqual(
@@ -2428,6 +2436,7 @@ class SerializationTest(unittest.TestCase):
proto2.MergeFromString(serialized))
self.assertEqual(123, proto2.Extensions[extension1].i)
self.assertEqual('foo', proto2.Extensions[extension2].str)
+ self.assertEqual('bar', proto2.Extensions[extension3].text)
# Check byte size.
self.assertEqual(proto2.ByteSize(), len(serialized))
@@ -2757,9 +2766,10 @@ class SerializationTest(unittest.TestCase):
def testInitArgsUnknownFieldName(self):
def InitalizeEmptyMessageWithExtraKeywordArg():
unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
- self._CheckRaises(ValueError,
- InitalizeEmptyMessageWithExtraKeywordArg,
- 'Protocol message has no "unknown" field.')
+ self._CheckRaises(
+ ValueError,
+ InitalizeEmptyMessageWithExtraKeywordArg,
+ 'Protocol message TestEmptyMessage has no "unknown" field.')
def testInitRequiredKwargs(self):
proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index cca0ee63..0e14556c 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -51,8 +51,22 @@ from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import test_util
+from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf import text_format
+
+# Low-level nuts-n-bolts tests.
+class SimpleTextFormatTests(unittest.TestCase):
+
+ # The members of _QUOTES are formatted into a regexp template that
+ # expects single characters. Therefore it's an error (in addition to being
+ # non-sensical in the first place) to try to specify a "quote mark" that is
+ # more than one character.
+ def TestQuoteMarksAreSingleChars(self):
+ for quote in text_format._QUOTES:
+ self.assertEqual(1, len(quote))
+
+
# Base class with some common functionality.
class TextFormatBase(unittest.TestCase):
@@ -287,6 +301,19 @@ class TextFormatTest(TextFormatBase):
self.assertEqual(u'one', message.repeated_string[0])
self.assertEqual(u'two', message.repeated_string[1])
+ def testParseRepeatedScalarShortFormat(self, message_module):
+ message = message_module.TestAllTypes()
+ text = ('repeated_int64: [100, 200];\n'
+ 'repeated_int64: 300,\n'
+ 'repeated_string: ["one", "two"];\n')
+ text_format.Parse(text, message)
+
+ self.assertEqual(100, message.repeated_int64[0])
+ self.assertEqual(200, message.repeated_int64[1])
+ self.assertEqual(300, message.repeated_int64[2])
+ self.assertEqual(u'one', message.repeated_string[0])
+ self.assertEqual(u'two', message.repeated_string[1])
+
def testParseEmptyText(self, message_module):
message = message_module.TestAllTypes()
text = ''
@@ -301,7 +328,7 @@ class TextFormatTest(TextFormatBase):
def testParseSingleWord(self, message_module):
message = message_module.TestAllTypes()
text = 'foo'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
(r'1:1 : Message type "\w+.TestAllTypes" has no field named '
r'"foo".'),
@@ -310,7 +337,7 @@ class TextFormatTest(TextFormatBase):
def testParseUnknownField(self, message_module):
message = message_module.TestAllTypes()
text = 'unknown_field: 8\n'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
(r'1:1 : Message type "\w+.TestAllTypes" has no field named '
r'"unknown_field".'),
@@ -319,7 +346,7 @@ class TextFormatTest(TextFormatBase):
def testParseBadEnumValue(self, message_module):
message = message_module.TestAllTypes()
text = 'optional_nested_enum: BARR'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
(r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
r'has no value named BARR.'),
@@ -327,7 +354,7 @@ class TextFormatTest(TextFormatBase):
message = message_module.TestAllTypes()
text = 'optional_nested_enum: 100'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
(r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
r'has no value with number 100.'),
@@ -336,7 +363,7 @@ class TextFormatTest(TextFormatBase):
def testParseBadIntValue(self, message_module):
message = message_module.TestAllTypes()
text = 'optional_int32: bork'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
('1:17 : Couldn\'t parse integer: bork'),
text_format.Parse, text, message)
@@ -553,6 +580,15 @@ class Proto2Tests(TextFormatBase):
' }\n'
'}\n')
+ message = message_set_extensions_pb2.TestMessageSet()
+ ext = message_set_extensions_pb2.message_set_extension3
+ message.Extensions[ext].text = 'bar'
+ self.CompareToGoldenText(
+ text_format.MessageToString(message),
+ '[google.protobuf.internal.TestMessageSetExtension3] {\n'
+ ' text: \"bar\"\n'
+ '}\n')
+
def testPrintMessageSetAsOneLine(self):
message = unittest_mset_pb2.TestMessageSetContainer()
ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
@@ -627,15 +663,125 @@ class Proto2Tests(TextFormatBase):
text_format.Parse(ascii_text, parsed_message)
self.assertEqual(message, parsed_message)
+ def testParseAllowedUnknownExtension(self):
+ # Skip over unknown extension correctly.
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ text = ('message_set {\n'
+ ' [unknown_extension] {\n'
+ ' i: 23\n'
+ ' [nested_unknown_ext]: {\n'
+ ' i: 23\n'
+ ' test: "test_string"\n'
+ ' floaty_float: -0.315\n'
+ ' num: -inf\n'
+ ' multiline_str: "abc"\n'
+ ' "def"\n'
+ ' "xyz."\n'
+ ' [nested_unknown_ext]: <\n'
+ ' i: 23\n'
+ ' i: 24\n'
+ ' pointfloat: .3\n'
+ ' test: "test_string"\n'
+ ' floaty_float: -0.315\n'
+ ' num: -inf\n'
+ ' long_string: "test" "test2" \n'
+ ' >\n'
+ ' }\n'
+ ' }\n'
+ ' [unknown_extension]: 5\n'
+ '}\n')
+ text_format.Parse(text, message, allow_unknown_extension=True)
+ golden = 'message_set {\n}\n'
+ self.CompareToGoldenText(text_format.MessageToString(message), golden)
+
+ # Catch parse errors in unknown extension.
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' [unknown_extension] {\n'
+ ' i:\n' # Missing value.
+ ' }\n'
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ 'Invalid field value: }',
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' [unknown_extension] {\n'
+ ' str: "malformed string\n' # Missing closing quote.
+ ' }\n'
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ 'Invalid field value: "',
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' [unknown_extension] {\n'
+ ' str: "malformed\n multiline\n string\n'
+ ' }\n'
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ 'Invalid field value: "',
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' [malformed_extension] <\n'
+ ' i: -5\n'
+ ' \n' # Missing '>' here.
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ '5:1 : Expected ">".',
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ # Don't allow unknown fields with allow_unknown_extension=True.
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' unknown_field: true\n'
+ ' \n' # Missing '>' here.
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ ('2:3 : Message type '
+ '"proto2_wireformat_unittest.TestMessageSet" has no'
+ ' field named "unknown_field".'),
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ # Parse known extension correcty.
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ text = ('message_set {\n'
+ ' [protobuf_unittest.TestMessageSetExtension1] {\n'
+ ' i: 23\n'
+ ' }\n'
+ ' [protobuf_unittest.TestMessageSetExtension2] {\n'
+ ' str: \"foo\"\n'
+ ' }\n'
+ '}\n')
+ text_format.Parse(text, message, allow_unknown_extension=True)
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ self.assertEqual(23, message.message_set.Extensions[ext1].i)
+ self.assertEqual('foo', message.message_set.Extensions[ext2].str)
+
def testParseBadExtension(self):
message = unittest_pb2.TestAllExtensions()
text = '[unknown_extension]: 8\n'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
'1:2 : Extension "unknown_extension" not registered.',
text_format.Parse, text, message)
message = unittest_pb2.TestAllTypes()
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
'extensions.'),
@@ -654,7 +800,7 @@ class Proto2Tests(TextFormatBase):
message = unittest_pb2.TestAllExtensions()
text = ('[protobuf_unittest.optional_int32_extension]: 42 '
'[protobuf_unittest.optional_int32_extension]: 67')
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
('1:96 : Message type "protobuf_unittest.TestAllExtensions" '
'should not have multiple '
@@ -665,7 +811,7 @@ class Proto2Tests(TextFormatBase):
message = unittest_pb2.TestAllTypes()
text = ('optional_nested_message { bb: 1 } '
'optional_nested_message { bb: 2 }')
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" '
'should not have multiple "bb" fields.'),
@@ -675,7 +821,7 @@ class Proto2Tests(TextFormatBase):
message = unittest_pb2.TestAllTypes()
text = ('optional_int32: 42 '
'optional_int32: 67')
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
('1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
'have multiple "optional_int32" fields.'),
@@ -684,11 +830,11 @@ class Proto2Tests(TextFormatBase):
def testParseGroupNotClosed(self):
message = unittest_pb2.TestAllTypes()
text = 'RepeatedGroup: <'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError, '1:16 : Expected ">".',
text_format.Parse, text, message)
text = 'RepeatedGroup: {'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError, '1:16 : Expected "}".',
text_format.Parse, text, message)
diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py
new file mode 100644
index 00000000..d3de9831
--- /dev/null
+++ b/python/google/protobuf/internal/well_known_types.py
@@ -0,0 +1,622 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Contains well known classes.
+
+This files defines well known classes which need extra maintenance including:
+ - Any
+ - Duration
+ - FieldMask
+ - Timestamp
+"""
+
+__author__ = 'jieluo@google.com (Jie Luo)'
+
+from datetime import datetime
+from datetime import timedelta
+
+from google.protobuf.descriptor import FieldDescriptor
+
+_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
+_NANOS_PER_SECOND = 1000000000
+_NANOS_PER_MILLISECOND = 1000000
+_NANOS_PER_MICROSECOND = 1000
+_MILLIS_PER_SECOND = 1000
+_MICROS_PER_SECOND = 1000000
+_SECONDS_PER_DAY = 24 * 3600
+
+
+class Error(Exception):
+ """Top-level module error."""
+
+
+class ParseError(Error):
+ """Thrown in case of parsing error."""
+
+
+class Any(object):
+ """Class for Any Message type."""
+
+ def Pack(self, msg):
+ """Packs the specified message into current Any message."""
+ self.type_url = 'type.googleapis.com/%s' % msg.DESCRIPTOR.full_name
+ self.value = msg.SerializeToString()
+
+ def Unpack(self, msg):
+ """Unpacks the current Any message into specified message."""
+ descriptor = msg.DESCRIPTOR
+ if not self.Is(descriptor):
+ return False
+ msg.ParseFromString(self.value)
+ return True
+
+ def Is(self, descriptor):
+ """Checks if this Any represents the given protobuf type."""
+ # Only last part is to be used: b/25630112
+ return self.type_url.split('/')[-1] == descriptor.full_name
+
+
+class Timestamp(object):
+ """Class for Timestamp message type."""
+
+ def ToJsonString(self):
+ """Converts Timestamp to RFC 3339 date string format.
+
+ Returns:
+ A string converted from timestamp. The string is always Z-normalized
+ and uses 3, 6 or 9 fractional digits as required to represent the
+ exact time. Example of the return format: '1972-01-01T10:00:20.021Z'
+ """
+ nanos = self.nanos % _NANOS_PER_SECOND
+ total_sec = self.seconds + (self.nanos - nanos) // _NANOS_PER_SECOND
+ seconds = total_sec % _SECONDS_PER_DAY
+ days = (total_sec - seconds) // _SECONDS_PER_DAY
+ dt = datetime(1970, 1, 1) + timedelta(days, seconds)
+
+ result = dt.isoformat()
+ if (nanos % 1e9) == 0:
+ # If there are 0 fractional digits, the fractional
+ # point '.' should be omitted when serializing.
+ return result + 'Z'
+ if (nanos % 1e6) == 0:
+ # Serialize 3 fractional digits.
+ return result + '.%03dZ' % (nanos / 1e6)
+ if (nanos % 1e3) == 0:
+ # Serialize 6 fractional digits.
+ return result + '.%06dZ' % (nanos / 1e3)
+ # Serialize 9 fractional digits.
+ return result + '.%09dZ' % nanos
+
+ def FromJsonString(self, value):
+ """Parse a RFC 3339 date string format to Timestamp.
+
+ Args:
+ value: A date string. Any fractional digits (or none) and any offset are
+ accepted as long as they fit into nano-seconds precision.
+ Example of accepted format: '1972-01-01T10:00:20.021-05:00'
+
+ Raises:
+ ParseError: On parsing problems.
+ """
+ timezone_offset = value.find('Z')
+ if timezone_offset == -1:
+ timezone_offset = value.find('+')
+ if timezone_offset == -1:
+ timezone_offset = value.rfind('-')
+ if timezone_offset == -1:
+ raise ParseError(
+ 'Failed to parse timestamp: missing valid timezone offset.')
+ time_value = value[0:timezone_offset]
+ # Parse datetime and nanos.
+ point_position = time_value.find('.')
+ if point_position == -1:
+ second_value = time_value
+ nano_value = ''
+ else:
+ second_value = time_value[:point_position]
+ nano_value = time_value[point_position + 1:]
+ date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT)
+ td = date_object - datetime(1970, 1, 1)
+ seconds = td.seconds + td.days * _SECONDS_PER_DAY
+ if len(nano_value) > 9:
+ raise ParseError(
+ 'Failed to parse Timestamp: nanos {0} more than '
+ '9 fractional digits.'.format(nano_value))
+ if nano_value:
+ nanos = round(float('0.' + nano_value) * 1e9)
+ else:
+ nanos = 0
+ # Parse timezone offsets.
+ if value[timezone_offset] == 'Z':
+ if len(value) != timezone_offset + 1:
+ raise ParseError('Failed to parse timestamp: invalid trailing'
+ ' data {0}.'.format(value))
+ else:
+ timezone = value[timezone_offset:]
+ pos = timezone.find(':')
+ if pos == -1:
+ raise ParseError(
+ 'Invalid timezone offset value: {0}.'.format(timezone))
+ if timezone[0] == '+':
+ seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
+ else:
+ seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
+ # Set seconds and nanos
+ self.seconds = int(seconds)
+ self.nanos = int(nanos)
+
+ def GetCurrentTime(self):
+ """Get the current UTC into Timestamp."""
+ self.FromDatetime(datetime.utcnow())
+
+ def ToNanoseconds(self):
+ """Converts Timestamp to nanoseconds since epoch."""
+ return self.seconds * _NANOS_PER_SECOND + self.nanos
+
+ def ToMicroseconds(self):
+ """Converts Timestamp to microseconds since epoch."""
+ return (self.seconds * _MICROS_PER_SECOND +
+ self.nanos // _NANOS_PER_MICROSECOND)
+
+ def ToMilliseconds(self):
+ """Converts Timestamp to milliseconds since epoch."""
+ return (self.seconds * _MILLIS_PER_SECOND +
+ self.nanos // _NANOS_PER_MILLISECOND)
+
+ def ToSeconds(self):
+ """Converts Timestamp to seconds since epoch."""
+ return self.seconds
+
+ def FromNanoseconds(self, nanos):
+ """Converts nanoseconds since epoch to Timestamp."""
+ self.seconds = nanos // _NANOS_PER_SECOND
+ self.nanos = nanos % _NANOS_PER_SECOND
+
+ def FromMicroseconds(self, micros):
+ """Converts microseconds since epoch to Timestamp."""
+ self.seconds = micros // _MICROS_PER_SECOND
+ self.nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND
+
+ def FromMilliseconds(self, millis):
+ """Converts milliseconds since epoch to Timestamp."""
+ self.seconds = millis // _MILLIS_PER_SECOND
+ self.nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND
+
+ def FromSeconds(self, seconds):
+ """Converts seconds since epoch to Timestamp."""
+ self.seconds = seconds
+ self.nanos = 0
+
+ def ToDatetime(self):
+ """Converts Timestamp to datetime."""
+ return datetime.utcfromtimestamp(
+ self.seconds + self.nanos / float(_NANOS_PER_SECOND))
+
+ def FromDatetime(self, dt):
+ """Converts datetime to Timestamp."""
+ td = dt - datetime(1970, 1, 1)
+ self.seconds = td.seconds + td.days * _SECONDS_PER_DAY
+ self.nanos = td.microseconds * _NANOS_PER_MICROSECOND
+
+
+class Duration(object):
+ """Class for Duration message type."""
+
+ def ToJsonString(self):
+ """Converts Duration to string format.
+
+ Returns:
+ A string converted from self. The string format will contains
+ 3, 6, or 9 fractional digits depending on the precision required to
+ represent the exact Duration value. For example: "1s", "1.010s",
+ "1.000000100s", "-3.100s"
+ """
+ if self.seconds < 0 or self.nanos < 0:
+ result = '-'
+ seconds = - self.seconds + int((0 - self.nanos) // 1e9)
+ nanos = (0 - self.nanos) % 1e9
+ else:
+ result = ''
+ seconds = self.seconds + int(self.nanos // 1e9)
+ nanos = self.nanos % 1e9
+ result += '%d' % seconds
+ if (nanos % 1e9) == 0:
+ # If there are 0 fractional digits, the fractional
+ # point '.' should be omitted when serializing.
+ return result + 's'
+ if (nanos % 1e6) == 0:
+ # Serialize 3 fractional digits.
+ return result + '.%03ds' % (nanos / 1e6)
+ if (nanos % 1e3) == 0:
+ # Serialize 6 fractional digits.
+ return result + '.%06ds' % (nanos / 1e3)
+ # Serialize 9 fractional digits.
+ return result + '.%09ds' % nanos
+
+ def FromJsonString(self, value):
+ """Converts a string to Duration.
+
+ Args:
+ value: A string to be converted. The string must end with 's'. Any
+ fractional digits (or none) are accepted as long as they fit into
+ precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s
+
+ Raises:
+ ParseError: On parsing problems.
+ """
+ if len(value) < 1 or value[-1] != 's':
+ raise ParseError(
+ 'Duration must end with letter "s": {0}.'.format(value))
+ try:
+ pos = value.find('.')
+ if pos == -1:
+ self.seconds = int(value[:-1])
+ self.nanos = 0
+ else:
+ self.seconds = int(value[:pos])
+ if value[0] == '-':
+ self.nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
+ else:
+ self.nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
+ except ValueError:
+ raise ParseError(
+ 'Couldn\'t parse duration: {0}.'.format(value))
+
+ def ToNanoseconds(self):
+ """Converts a Duration to nanoseconds."""
+ return self.seconds * _NANOS_PER_SECOND + self.nanos
+
+ def ToMicroseconds(self):
+ """Converts a Duration to microseconds."""
+ micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND)
+ return self.seconds * _MICROS_PER_SECOND + micros
+
+ def ToMilliseconds(self):
+ """Converts a Duration to milliseconds."""
+ millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND)
+ return self.seconds * _MILLIS_PER_SECOND + millis
+
+ def ToSeconds(self):
+ """Converts a Duration to seconds."""
+ return self.seconds
+
+ def FromNanoseconds(self, nanos):
+ """Converts nanoseconds to Duration."""
+ self._NormalizeDuration(nanos // _NANOS_PER_SECOND,
+ nanos % _NANOS_PER_SECOND)
+
+ def FromMicroseconds(self, micros):
+ """Converts microseconds to Duration."""
+ self._NormalizeDuration(
+ micros // _MICROS_PER_SECOND,
+ (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND)
+
+ def FromMilliseconds(self, millis):
+ """Converts milliseconds to Duration."""
+ self._NormalizeDuration(
+ millis // _MILLIS_PER_SECOND,
+ (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND)
+
+ def FromSeconds(self, seconds):
+ """Converts seconds to Duration."""
+ self.seconds = seconds
+ self.nanos = 0
+
+ def ToTimedelta(self):
+ """Converts Duration to timedelta."""
+ return timedelta(
+ seconds=self.seconds, microseconds=_RoundTowardZero(
+ self.nanos, _NANOS_PER_MICROSECOND))
+
+ def FromTimedelta(self, td):
+ """Convertd timedelta to Duration."""
+ self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
+ td.microseconds * _NANOS_PER_MICROSECOND)
+
+ def _NormalizeDuration(self, seconds, nanos):
+ """Set Duration by seconds and nonas."""
+ # Force nanos to be negative if the duration is negative.
+ if seconds < 0 and nanos > 0:
+ seconds += 1
+ nanos -= _NANOS_PER_SECOND
+ self.seconds = seconds
+ self.nanos = nanos
+
+
+def _RoundTowardZero(value, divider):
+ """Truncates the remainder part after division."""
+ # For some languanges, the sign of the remainder is implementation
+ # dependent if any of the operands is negative. Here we enforce
+ # "rounded toward zero" semantics. For example, for (-5) / 2 an
+ # implementation may give -3 as the result with the remainder being
+ # 1. This function ensures we always return -2 (closer to zero).
+ result = value // divider
+ remainder = value % divider
+ if result < 0 and remainder > 0:
+ return result + 1
+ else:
+ return result
+
+
+class FieldMask(object):
+ """Class for FieldMask message type."""
+
+ def ToJsonString(self):
+ """Converts FieldMask to string according to proto3 JSON spec."""
+ return ','.join(self.paths)
+
+ def FromJsonString(self, value):
+ """Converts string to FieldMask according to proto3 JSON spec."""
+ self.Clear()
+ for path in value.split(','):
+ self.paths.append(path)
+
+ def IsValidForDescriptor(self, message_descriptor):
+ """Checks whether the FieldMask is valid for Message Descriptor."""
+ for path in self.paths:
+ if not _IsValidPath(message_descriptor, path):
+ return False
+ return True
+
+ def AllFieldsFromDescriptor(self, message_descriptor):
+ """Gets all direct fields of Message Descriptor to FieldMask."""
+ self.Clear()
+ for field in message_descriptor.fields:
+ self.paths.append(field.name)
+
+ def CanonicalFormFromMask(self, mask):
+ """Converts a FieldMask to the canonical form.
+
+ Removes paths that are covered by another path. For example,
+ "foo.bar" is covered by "foo" and will be removed if "foo"
+ is also in the FieldMask. Then sorts all paths in alphabetical order.
+
+ Args:
+ mask: The original FieldMask to be converted.
+ """
+ tree = _FieldMaskTree(mask)
+ tree.ToFieldMask(self)
+
+ def Union(self, mask1, mask2):
+ """Merges mask1 and mask2 into this FieldMask."""
+ _CheckFieldMaskMessage(mask1)
+ _CheckFieldMaskMessage(mask2)
+ tree = _FieldMaskTree(mask1)
+ tree.MergeFromFieldMask(mask2)
+ tree.ToFieldMask(self)
+
+ def Intersect(self, mask1, mask2):
+ """Intersects mask1 and mask2 into this FieldMask."""
+ _CheckFieldMaskMessage(mask1)
+ _CheckFieldMaskMessage(mask2)
+ tree = _FieldMaskTree(mask1)
+ intersection = _FieldMaskTree()
+ for path in mask2.paths:
+ tree.IntersectPath(path, intersection)
+ intersection.ToFieldMask(self)
+
+ def MergeMessage(
+ self, source, destination,
+ replace_message_field=False, replace_repeated_field=False):
+ """Merges fields specified in FieldMask from source to destination.
+
+ Args:
+ source: Source message.
+ destination: The destination message to be merged into.
+ replace_message_field: Replace message field if True. Merge message
+ field if False.
+ replace_repeated_field: Replace repeated field if True. Append
+ elements of repeated field if False.
+ """
+ tree = _FieldMaskTree(self)
+ tree.MergeMessage(
+ source, destination, replace_message_field, replace_repeated_field)
+
+
+def _IsValidPath(message_descriptor, path):
+ """Checks whether the path is valid for Message Descriptor."""
+ parts = path.split('.')
+ last = parts.pop()
+ for name in parts:
+ field = message_descriptor.fields_by_name[name]
+ if (field is None or
+ field.label == FieldDescriptor.LABEL_REPEATED or
+ field.type != FieldDescriptor.TYPE_MESSAGE):
+ return False
+ message_descriptor = field.message_type
+ return last in message_descriptor.fields_by_name
+
+
+def _CheckFieldMaskMessage(message):
+ """Raises ValueError if message is not a FieldMask."""
+ message_descriptor = message.DESCRIPTOR
+ if (message_descriptor.name != 'FieldMask' or
+ message_descriptor.file.name != 'google/protobuf/field_mask.proto'):
+ raise ValueError('Message {0} is not a FieldMask.'.format(
+ message_descriptor.full_name))
+
+
+class _FieldMaskTree(object):
+ """Represents a FieldMask in a tree structure.
+
+ For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
+ the FieldMaskTree will be:
+ [_root] -+- foo -+- bar
+ | |
+ | +- baz
+ |
+ +- bar --- baz
+ In the tree, each leaf node represents a field path.
+ """
+
+ def __init__(self, field_mask=None):
+ """Initializes the tree by FieldMask."""
+ self._root = {}
+ if field_mask:
+ self.MergeFromFieldMask(field_mask)
+
+ def MergeFromFieldMask(self, field_mask):
+ """Merges a FieldMask to the tree."""
+ for path in field_mask.paths:
+ self.AddPath(path)
+
+ def AddPath(self, path):
+ """Adds a field path into the tree.
+
+ If the field path to add is a sub-path of an existing field path
+ in the tree (i.e., a leaf node), it means the tree already matches
+ the given path so nothing will be added to the tree. If the path
+ matches an existing non-leaf node in the tree, that non-leaf node
+ will be turned into a leaf node with all its children removed because
+ the path matches all the node's children. Otherwise, a new path will
+ be added.
+
+ Args:
+ path: The field path to add.
+ """
+ node = self._root
+ for name in path.split('.'):
+ if name not in node:
+ node[name] = {}
+ elif not node[name]:
+ # Pre-existing empty node implies we already have this entire tree.
+ return
+ node = node[name]
+ # Remove any sub-trees we might have had.
+ node.clear()
+
+ def ToFieldMask(self, field_mask):
+ """Converts the tree to a FieldMask."""
+ field_mask.Clear()
+ _AddFieldPaths(self._root, '', field_mask)
+
+ def IntersectPath(self, path, intersection):
+ """Calculates the intersection part of a field path with this tree.
+
+ Args:
+ path: The field path to calculates.
+ intersection: The out tree to record the intersection part.
+ """
+ node = self._root
+ for name in path.split('.'):
+ if name not in node:
+ return
+ elif not node[name]:
+ intersection.AddPath(path)
+ return
+ node = node[name]
+ intersection.AddLeafNodes(path, node)
+
+ def AddLeafNodes(self, prefix, node):
+ """Adds leaf nodes begin with prefix to this tree."""
+ if not node:
+ self.AddPath(prefix)
+ for name in node:
+ child_path = prefix + '.' + name
+ self.AddLeafNodes(child_path, node[name])
+
+ def MergeMessage(
+ self, source, destination,
+ replace_message, replace_repeated):
+ """Merge all fields specified by this tree from source to destination."""
+ _MergeMessage(
+ self._root, source, destination, replace_message, replace_repeated)
+
+
+def _StrConvert(value):
+ """Converts value to str if it is not."""
+ # This file is imported by c extension and some methods like ClearField
+ # requires string for the field name. py2/py3 has different text
+ # type and may use unicode.
+ if not isinstance(value, str):
+ return value.encode('utf-8')
+ return value
+
+
+def _MergeMessage(
+ node, source, destination, replace_message, replace_repeated):
+ """Merge all fields specified by a sub-tree from source to destination."""
+ source_descriptor = source.DESCRIPTOR
+ for name in node:
+ child = node[name]
+ field = source_descriptor.fields_by_name[name]
+ if field is None:
+ raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
+ name, source_descriptor.full_name))
+ if child:
+ # Sub-paths are only allowed for singular message fields.
+ if (field.label == FieldDescriptor.LABEL_REPEATED or
+ field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
+ raise ValueError('Error: Field {0} in message {1} is not a singular '
+ 'message field and cannot have sub-fields.'.format(
+ name, source_descriptor.full_name))
+ _MergeMessage(
+ child, getattr(source, name), getattr(destination, name),
+ replace_message, replace_repeated)
+ continue
+ if field.label == FieldDescriptor.LABEL_REPEATED:
+ if replace_repeated:
+ destination.ClearField(_StrConvert(name))
+ repeated_source = getattr(source, name)
+ repeated_destination = getattr(destination, name)
+ if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
+ for item in repeated_source:
+ repeated_destination.add().MergeFrom(item)
+ else:
+ repeated_destination.extend(repeated_source)
+ else:
+ if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
+ if replace_message:
+ destination.ClearField(_StrConvert(name))
+ if source.HasField(name):
+ getattr(destination, name).MergeFrom(getattr(source, name))
+ else:
+ setattr(destination, name, getattr(source, name))
+
+
+def _AddFieldPaths(node, prefix, field_mask):
+ """Adds the field paths descended from node to field_mask."""
+ if not node:
+ field_mask.paths.append(prefix)
+ return
+ for name in sorted(node):
+ if prefix:
+ child_path = prefix + '.' + name
+ else:
+ child_path = name
+ _AddFieldPaths(node[name], child_path, field_mask)
+
+
+WKTBASES = {
+ 'google.protobuf.Any': Any,
+ 'google.protobuf.Duration': Duration,
+ 'google.protobuf.FieldMask': FieldMask,
+ 'google.protobuf.Timestamp': Timestamp,
+}
diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py
new file mode 100644
index 00000000..60b0c47d
--- /dev/null
+++ b/python/google/protobuf/internal/well_known_types_test.py
@@ -0,0 +1,509 @@
+#! /usr/bin/env python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test for google.protobuf.internal.well_known_types."""
+
+__author__ = 'jieluo@google.com (Jie Luo)'
+
+from datetime import datetime
+
+from google.protobuf import duration_pb2
+from google.protobuf import field_mask_pb2
+from google.protobuf import timestamp_pb2
+import unittest
+from google.protobuf import unittest_pb2
+from google.protobuf.internal import test_util
+from google.protobuf.internal import well_known_types
+from google.protobuf import descriptor
+
+
+class TimeUtilTestBase(unittest.TestCase):
+
+ def CheckTimestampConversion(self, message, text):
+ self.assertEqual(text, message.ToJsonString())
+ parsed_message = timestamp_pb2.Timestamp()
+ parsed_message.FromJsonString(text)
+ self.assertEqual(message, parsed_message)
+
+ def CheckDurationConversion(self, message, text):
+ self.assertEqual(text, message.ToJsonString())
+ parsed_message = duration_pb2.Duration()
+ parsed_message.FromJsonString(text)
+ self.assertEqual(message, parsed_message)
+
+
+class TimeUtilTest(TimeUtilTestBase):
+
+ def testTimestampSerializeAndParse(self):
+ message = timestamp_pb2.Timestamp()
+ # Generated output should contain 3, 6, or 9 fractional digits.
+ message.seconds = 0
+ message.nanos = 0
+ self.CheckTimestampConversion(message, '1970-01-01T00:00:00Z')
+ message.nanos = 10000000
+ self.CheckTimestampConversion(message, '1970-01-01T00:00:00.010Z')
+ message.nanos = 10000
+ self.CheckTimestampConversion(message, '1970-01-01T00:00:00.000010Z')
+ message.nanos = 10
+ self.CheckTimestampConversion(message, '1970-01-01T00:00:00.000000010Z')
+ # Test min timestamps.
+ message.seconds = -62135596800
+ message.nanos = 0
+ self.CheckTimestampConversion(message, '0001-01-01T00:00:00Z')
+ # Test max timestamps.
+ message.seconds = 253402300799
+ message.nanos = 999999999
+ self.CheckTimestampConversion(message, '9999-12-31T23:59:59.999999999Z')
+ # Test negative timestamps.
+ message.seconds = -1
+ self.CheckTimestampConversion(message, '1969-12-31T23:59:59.999999999Z')
+
+ # Parsing accepts an fractional digits as long as they fit into nano
+ # precision.
+ message.FromJsonString('1970-01-01T00:00:00.1Z')
+ self.assertEqual(0, message.seconds)
+ self.assertEqual(100000000, message.nanos)
+ # Parsing accpets offsets.
+ message.FromJsonString('1970-01-01T00:00:00-08:00')
+ self.assertEqual(8 * 3600, message.seconds)
+ self.assertEqual(0, message.nanos)
+
+ def testDurationSerializeAndParse(self):
+ message = duration_pb2.Duration()
+ # Generated output should contain 3, 6, or 9 fractional digits.
+ message.seconds = 0
+ message.nanos = 0
+ self.CheckDurationConversion(message, '0s')
+ message.nanos = 10000000
+ self.CheckDurationConversion(message, '0.010s')
+ message.nanos = 10000
+ self.CheckDurationConversion(message, '0.000010s')
+ message.nanos = 10
+ self.CheckDurationConversion(message, '0.000000010s')
+
+ # Test min and max
+ message.seconds = 315576000000
+ message.nanos = 999999999
+ self.CheckDurationConversion(message, '315576000000.999999999s')
+ message.seconds = -315576000000
+ message.nanos = -999999999
+ self.CheckDurationConversion(message, '-315576000000.999999999s')
+
+ # Parsing accepts an fractional digits as long as they fit into nano
+ # precision.
+ message.FromJsonString('0.1s')
+ self.assertEqual(100000000, message.nanos)
+ message.FromJsonString('0.0000001s')
+ self.assertEqual(100, message.nanos)
+
+ def testTimestampIntegerConversion(self):
+ message = timestamp_pb2.Timestamp()
+ message.FromNanoseconds(1)
+ self.assertEqual('1970-01-01T00:00:00.000000001Z',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToNanoseconds())
+
+ message.FromNanoseconds(-1)
+ self.assertEqual('1969-12-31T23:59:59.999999999Z',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToNanoseconds())
+
+ message.FromMicroseconds(1)
+ self.assertEqual('1970-01-01T00:00:00.000001Z',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToMicroseconds())
+
+ message.FromMicroseconds(-1)
+ self.assertEqual('1969-12-31T23:59:59.999999Z',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToMicroseconds())
+
+ message.FromMilliseconds(1)
+ self.assertEqual('1970-01-01T00:00:00.001Z',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToMilliseconds())
+
+ message.FromMilliseconds(-1)
+ self.assertEqual('1969-12-31T23:59:59.999Z',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToMilliseconds())
+
+ message.FromSeconds(1)
+ self.assertEqual('1970-01-01T00:00:01Z',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToSeconds())
+
+ message.FromSeconds(-1)
+ self.assertEqual('1969-12-31T23:59:59Z',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToSeconds())
+
+ message.FromNanoseconds(1999)
+ self.assertEqual(1, message.ToMicroseconds())
+ # For negative values, Timestamp will be rounded down.
+ # For example, "1969-12-31T23:59:59.5Z" (i.e., -0.5s) rounded to seconds
+ # will be "1969-12-31T23:59:59Z" (i.e., -1s) rather than
+ # "1970-01-01T00:00:00Z" (i.e., 0s).
+ message.FromNanoseconds(-1999)
+ self.assertEqual(-2, message.ToMicroseconds())
+
+ def testDurationIntegerConversion(self):
+ message = duration_pb2.Duration()
+ message.FromNanoseconds(1)
+ self.assertEqual('0.000000001s',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToNanoseconds())
+
+ message.FromNanoseconds(-1)
+ self.assertEqual('-0.000000001s',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToNanoseconds())
+
+ message.FromMicroseconds(1)
+ self.assertEqual('0.000001s',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToMicroseconds())
+
+ message.FromMicroseconds(-1)
+ self.assertEqual('-0.000001s',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToMicroseconds())
+
+ message.FromMilliseconds(1)
+ self.assertEqual('0.001s',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToMilliseconds())
+
+ message.FromMilliseconds(-1)
+ self.assertEqual('-0.001s',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToMilliseconds())
+
+ message.FromSeconds(1)
+ self.assertEqual('1s', message.ToJsonString())
+ self.assertEqual(1, message.ToSeconds())
+
+ message.FromSeconds(-1)
+ self.assertEqual('-1s',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToSeconds())
+
+ # Test truncation behavior.
+ message.FromNanoseconds(1999)
+ self.assertEqual(1, message.ToMicroseconds())
+
+ # For negative values, Duration will be rounded towards 0.
+ message.FromNanoseconds(-1999)
+ self.assertEqual(-1, message.ToMicroseconds())
+
+ def testDatetimeConverison(self):
+ message = timestamp_pb2.Timestamp()
+ dt = datetime(1970, 1, 1)
+ message.FromDatetime(dt)
+ self.assertEqual(dt, message.ToDatetime())
+
+ message.FromMilliseconds(1999)
+ self.assertEqual(datetime(1970, 1, 1, 0, 0, 1, 999000),
+ message.ToDatetime())
+
+ def testTimedeltaConversion(self):
+ message = duration_pb2.Duration()
+ message.FromNanoseconds(1999999999)
+ td = message.ToTimedelta()
+ self.assertEqual(1, td.seconds)
+ self.assertEqual(999999, td.microseconds)
+
+ message.FromNanoseconds(-1999999999)
+ td = message.ToTimedelta()
+ self.assertEqual(-1, td.days)
+ self.assertEqual(86398, td.seconds)
+ self.assertEqual(1, td.microseconds)
+
+ message.FromMicroseconds(-1)
+ td = message.ToTimedelta()
+ self.assertEqual(-1, td.days)
+ self.assertEqual(86399, td.seconds)
+ self.assertEqual(999999, td.microseconds)
+ converted_message = duration_pb2.Duration()
+ converted_message.FromTimedelta(td)
+ self.assertEqual(message, converted_message)
+
+ def testInvalidTimestamp(self):
+ message = timestamp_pb2.Timestamp()
+ self.assertRaisesRegexp(
+ ValueError,
+ 'time data \'10000-01-01T00:00:00\' does not match'
+ ' format \'%Y-%m-%dT%H:%M:%S\'',
+ message.FromJsonString, '10000-01-01T00:00:00.00Z')
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'nanos 0123456789012 more than 9 fractional digits.',
+ message.FromJsonString,
+ '1970-01-01T00:00:00.0123456789012Z')
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ (r'Invalid timezone offset value: \+08.'),
+ message.FromJsonString,
+ '1972-01-01T01:00:00.01+08',)
+ self.assertRaisesRegexp(
+ ValueError,
+ 'year is out of range',
+ message.FromJsonString,
+ '0000-01-01T00:00:00Z')
+ message.seconds = 253402300800
+ self.assertRaisesRegexp(
+ OverflowError,
+ 'date value out of range',
+ message.ToJsonString)
+
+ def testInvalidDuration(self):
+ message = duration_pb2.Duration()
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'Duration must end with letter "s": 1.',
+ message.FromJsonString, '1')
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'Couldn\'t parse duration: 1...2s.',
+ message.FromJsonString, '1...2s')
+
+
+class FieldMaskTest(unittest.TestCase):
+
+ def testStringFormat(self):
+ mask = field_mask_pb2.FieldMask()
+ self.assertEqual('', mask.ToJsonString())
+ mask.paths.append('foo')
+ self.assertEqual('foo', mask.ToJsonString())
+ mask.paths.append('bar')
+ self.assertEqual('foo,bar', mask.ToJsonString())
+
+ mask.FromJsonString('')
+ self.assertEqual('', mask.ToJsonString())
+ mask.FromJsonString('foo')
+ self.assertEqual(['foo'], mask.paths)
+ mask.FromJsonString('foo,bar')
+ self.assertEqual(['foo', 'bar'], mask.paths)
+
+ def testDescriptorToFieldMask(self):
+ mask = field_mask_pb2.FieldMask()
+ msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ mask.AllFieldsFromDescriptor(msg_descriptor)
+ self.assertEqual(75, len(mask.paths))
+ self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
+ for field in msg_descriptor.fields:
+ self.assertTrue(field.name in mask.paths)
+ mask.paths.append('optional_nested_message.bb')
+ self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
+ mask.paths.append('repeated_nested_message.bb')
+ self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
+
+ def testCanonicalFrom(self):
+ mask = field_mask_pb2.FieldMask()
+ out_mask = field_mask_pb2.FieldMask()
+ # Paths will be sorted.
+ mask.FromJsonString('baz.quz,bar,foo')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('bar,baz.quz,foo', out_mask.ToJsonString())
+ # Duplicated paths will be removed.
+ mask.FromJsonString('foo,bar,foo')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('bar,foo', out_mask.ToJsonString())
+ # Sub-paths of other paths will be removed.
+ mask.FromJsonString('foo.b1,bar.b1,foo.b2,bar')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('bar,foo.b1,foo.b2', out_mask.ToJsonString())
+
+ # Test more deeply nested cases.
+ mask.FromJsonString(
+ 'foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('foo.bar.baz1,foo.bar.baz2',
+ out_mask.ToJsonString())
+ mask.FromJsonString(
+ 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('foo.bar.baz1,foo.bar.baz2',
+ out_mask.ToJsonString())
+ mask.FromJsonString(
+ 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('foo.bar', out_mask.ToJsonString())
+ mask.FromJsonString(
+ 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('foo', out_mask.ToJsonString())
+
+ def testUnion(self):
+ mask1 = field_mask_pb2.FieldMask()
+ mask2 = field_mask_pb2.FieldMask()
+ out_mask = field_mask_pb2.FieldMask()
+ mask1.FromJsonString('foo,baz')
+ mask2.FromJsonString('bar,quz')
+ out_mask.Union(mask1, mask2)
+ self.assertEqual('bar,baz,foo,quz', out_mask.ToJsonString())
+ # Overlap with duplicated paths.
+ mask1.FromJsonString('foo,baz.bb')
+ mask2.FromJsonString('baz.bb,quz')
+ out_mask.Union(mask1, mask2)
+ self.assertEqual('baz.bb,foo,quz', out_mask.ToJsonString())
+ # Overlap with paths covering some other paths.
+ mask1.FromJsonString('foo.bar.baz,quz')
+ mask2.FromJsonString('foo.bar,bar')
+ out_mask.Union(mask1, mask2)
+ self.assertEqual('bar,foo.bar,quz', out_mask.ToJsonString())
+
+ def testIntersect(self):
+ mask1 = field_mask_pb2.FieldMask()
+ mask2 = field_mask_pb2.FieldMask()
+ out_mask = field_mask_pb2.FieldMask()
+ # Test cases without overlapping.
+ mask1.FromJsonString('foo,baz')
+ mask2.FromJsonString('bar,quz')
+ out_mask.Intersect(mask1, mask2)
+ self.assertEqual('', out_mask.ToJsonString())
+ # Overlap with duplicated paths.
+ mask1.FromJsonString('foo,baz.bb')
+ mask2.FromJsonString('baz.bb,quz')
+ out_mask.Intersect(mask1, mask2)
+ self.assertEqual('baz.bb', out_mask.ToJsonString())
+ # Overlap with paths covering some other paths.
+ mask1.FromJsonString('foo.bar.baz,quz')
+ mask2.FromJsonString('foo.bar,bar')
+ out_mask.Intersect(mask1, mask2)
+ self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
+ mask1.FromJsonString('foo.bar,bar')
+ mask2.FromJsonString('foo.bar.baz,quz')
+ out_mask.Intersect(mask1, mask2)
+ self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
+
+ def testMergeMessage(self):
+ # Test merge one field.
+ src = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(src)
+ for field in src.DESCRIPTOR.fields:
+ if field.containing_oneof:
+ continue
+ field_name = field.name
+ dst = unittest_pb2.TestAllTypes()
+ # Only set one path to mask.
+ mask = field_mask_pb2.FieldMask()
+ mask.paths.append(field_name)
+ mask.MergeMessage(src, dst)
+ # The expected result message.
+ msg = unittest_pb2.TestAllTypes()
+ if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ repeated_src = getattr(src, field_name)
+ repeated_msg = getattr(msg, field_name)
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ for item in repeated_src:
+ repeated_msg.add().CopyFrom(item)
+ else:
+ repeated_msg.extend(repeated_src)
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ getattr(msg, field_name).CopyFrom(getattr(src, field_name))
+ else:
+ setattr(msg, field_name, getattr(src, field_name))
+ # Only field specified in mask is merged.
+ self.assertEqual(msg, dst)
+
+ # Test merge nested fields.
+ nested_src = unittest_pb2.NestedTestAllTypes()
+ nested_dst = unittest_pb2.NestedTestAllTypes()
+ nested_src.child.payload.optional_int32 = 1234
+ nested_src.child.child.payload.optional_int32 = 5678
+ mask = field_mask_pb2.FieldMask()
+ mask.FromJsonString('child.payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(0, nested_dst.child.child.payload.optional_int32)
+
+ mask.FromJsonString('child.child.payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
+
+ nested_dst.Clear()
+ mask.FromJsonString('child.child.payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(0, nested_dst.child.payload.optional_int32)
+ self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
+
+ nested_dst.Clear()
+ mask.FromJsonString('child')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
+
+ # Test MergeOptions.
+ nested_dst.Clear()
+ nested_dst.child.payload.optional_int64 = 4321
+ # Message fields will be merged by default.
+ mask.FromJsonString('child.payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(4321, nested_dst.child.payload.optional_int64)
+ # Change the behavior to replace message fields.
+ mask.FromJsonString('child.payload')
+ mask.MergeMessage(nested_src, nested_dst, True, False)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(0, nested_dst.child.payload.optional_int64)
+
+ # By default, fields missing in source are not cleared in destination.
+ nested_dst.payload.optional_int32 = 1234
+ self.assertTrue(nested_dst.HasField('payload'))
+ mask.FromJsonString('payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertTrue(nested_dst.HasField('payload'))
+ # But they are cleared when replacing message fields.
+ nested_dst.Clear()
+ nested_dst.payload.optional_int32 = 1234
+ mask.FromJsonString('payload')
+ mask.MergeMessage(nested_src, nested_dst, True, False)
+ self.assertFalse(nested_dst.HasField('payload'))
+
+ nested_src.payload.repeated_int32.append(1234)
+ nested_dst.payload.repeated_int32.append(5678)
+ # Repeated fields will be appended by default.
+ mask.FromJsonString('payload.repeated_int32')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(2, len(nested_dst.payload.repeated_int32))
+ self.assertEqual(5678, nested_dst.payload.repeated_int32[0])
+ self.assertEqual(1234, nested_dst.payload.repeated_int32[1])
+ # Change the behavior to replace repeated fields.
+ mask.FromJsonString('payload.repeated_int32')
+ mask.MergeMessage(nested_src, nested_dst, False, True)
+ self.assertEqual(1, len(nested_dst.payload.repeated_int32))
+ self.assertEqual(1234, nested_dst.payload.repeated_int32[0])
+
+if __name__ == '__main__':
+ unittest.main()