author    Feng Xiao <xfxyjwf@gmail.com>  2015-12-11 17:09:20 -0800
committer Feng Xiao <xfxyjwf@gmail.com>  2015-12-11 17:10:28 -0800
commit    e841bac4fcf47f809e089a70d5f84ac37b3883df (patch)
tree      d25dc5fc814db182c04c5f276ff1a609c5965a5a /python
parent    99a6a95c751a28a3cc33dd2384959179f83f682c (diff)
Down-integrate from internal code base.
Diffstat (limited to 'python')
-rwxr-xr-x  python/google/protobuf/descriptor.py | 12
-rw-r--r--  python/google/protobuf/descriptor_database.py | 4
-rw-r--r--  python/google/protobuf/descriptor_pool.py | 69
-rw-r--r--  python/google/protobuf/internal/any_test.proto | 42
-rwxr-xr-x  python/google/protobuf/internal/containers.py | 23
-rw-r--r--  python/google/protobuf/internal/descriptor_pool_test.py | 176
-rwxr-xr-x  python/google/protobuf/internal/descriptor_test.py | 20
-rw-r--r--  python/google/protobuf/internal/json_format_test.py | 58
-rw-r--r--  python/google/protobuf/internal/message_factory_test.py | 7
-rw-r--r--  python/google/protobuf/internal/message_set_extensions.proto | 8
-rwxr-xr-x  python/google/protobuf/internal/message_test.py | 81
-rwxr-xr-x  python/google/protobuf/internal/python_message.py | 45
-rwxr-xr-x  python/google/protobuf/internal/reflection_test.py | 18
-rwxr-xr-x  python/google/protobuf/internal/text_format_test.py | 170
-rw-r--r--  python/google/protobuf/internal/well_known_types.py | 622
-rw-r--r--  python/google/protobuf/internal/well_known_types_test.py | 509
-rw-r--r--  python/google/protobuf/json_format.py | 269
-rw-r--r--  python/google/protobuf/message_factory.py | 4
-rw-r--r--  python/google/protobuf/pyext/descriptor.cc | 29
-rw-r--r--  python/google/protobuf/pyext/descriptor_database.cc | 145
-rw-r--r--  python/google/protobuf/pyext/descriptor_database.h | 75
-rw-r--r--  python/google/protobuf/pyext/descriptor_pool.cc | 111
-rw-r--r--  python/google/protobuf/pyext/descriptor_pool.h | 10
-rw-r--r--  python/google/protobuf/pyext/extension_dict.cc | 47
-rw-r--r--  python/google/protobuf/pyext/extension_dict.h | 5
-rw-r--r--  python/google/protobuf/pyext/map_container.cc | 912
-rw-r--r--  python/google/protobuf/pyext/map_container.h (renamed from python/google/protobuf/pyext/message_map_container.h) | 73
-rw-r--r--  python/google/protobuf/pyext/message.cc | 299
-rw-r--r--  python/google/protobuf/pyext/message.h | 1
-rw-r--r--  python/google/protobuf/pyext/message_map_container.cc | 569
-rw-r--r--  python/google/protobuf/pyext/scalar_map_container.cc | 542
-rw-r--r--  python/google/protobuf/pyext/scalar_map_container.h | 119
-rw-r--r--  python/google/protobuf/symbol_database.py | 32
-rwxr-xr-x  python/google/protobuf/text_format.py | 264
-rwxr-xr-x  python/setup.py | 1
35 files changed, 3564 insertions(+), 1807 deletions(-)
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py
index 2bf36532..5f613c88 100755
--- a/python/google/protobuf/descriptor.py
+++ b/python/google/protobuf/descriptor.py
@@ -786,25 +786,33 @@ class FileDescriptor(DescriptorBase):
message_types_by_name: Dict of message names and their descriptors.
enum_types_by_name: Dict of enum names and their descriptors.
extensions_by_name: Dict of extension names and their descriptors.
+ pool: the DescriptorPool this descriptor belongs to. When not passed to the
+ constructor, the global default pool is used.
"""
if _USE_C_DESCRIPTORS:
_C_DESCRIPTOR_CLASS = _message.FileDescriptor
def __new__(cls, name, package, options=None, serialized_pb=None,
- dependencies=None, syntax=None):
+ dependencies=None, syntax=None, pool=None):
# FileDescriptor() is called from various places, not only from generated
# files, to register dynamic proto files and messages.
if serialized_pb:
+ # TODO(amauryfa): use the pool passed as argument. This will work only
+ # for C++-implemented DescriptorPools.
return _message.default_pool.AddSerializedFile(serialized_pb)
else:
return super(FileDescriptor, cls).__new__(cls)
def __init__(self, name, package, options=None, serialized_pb=None,
- dependencies=None, syntax=None):
+ dependencies=None, syntax=None, pool=None):
"""Constructor."""
super(FileDescriptor, self).__init__(options, 'FileOptions')
+ if pool is None:
+ from google.protobuf import descriptor_pool
+ pool = descriptor_pool.Default()
+ self.pool = pool
self.message_types_by_name = {}
self.name = name
self.package = package
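
Usage sketch (not part of the diff): with this change a FileDescriptor can be bound to an explicit DescriptorPool, and falls back to the global default pool when none is given. A minimal sketch, assuming the pure-Python implementation and hypothetical file names:

    from google.protobuf import descriptor
    from google.protobuf import descriptor_pool

    pool = descriptor_pool.DescriptorPool()
    fd = descriptor.FileDescriptor(
        name='sample/file.proto', package='sample', pool=pool)
    assert fd.pool is pool

    # With pool=None, descriptor_pool.Default() is looked up lazily.
    fd2 = descriptor.FileDescriptor(name='sample/other.proto', package='sample')
    assert fd2.pool is descriptor_pool.Default()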
diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py
index b10021e9..1333f996 100644
--- a/python/google/protobuf/descriptor_database.py
+++ b/python/google/protobuf/descriptor_database.py
@@ -65,6 +65,7 @@ class DescriptorDatabase(object):
raise DescriptorDatabaseConflictingDefinitionError(
'%s already added, but with different descriptor.' % proto_name)
+ # Add the top-level Message, Enum and Extension descriptors to the index.
package = file_desc_proto.package
for message in file_desc_proto.message_type:
self._file_desc_protos_by_symbol.update(
@@ -72,6 +73,9 @@ class DescriptorDatabase(object):
for enum in file_desc_proto.enum_type:
self._file_desc_protos_by_symbol[
'.'.join((package, enum.name))] = file_desc_proto
+ for extension in file_desc_proto.extension:
+ self._file_desc_protos_by_symbol[
+ '.'.join((package, extension.name))] = file_desc_proto
def FindFileByName(self, name):
"""Finds the file descriptor proto by file name.
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index 6a1b4b5e..3e80795c 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -83,6 +83,12 @@ def _NormalizeFullyQualifiedName(name):
class DescriptorPool(object):
"""A collection of protobufs dynamically constructed by descriptor protos."""
+ if _USE_C_DESCRIPTORS:
+
+ def __new__(cls, descriptor_db=None):
+ # pylint: disable=protected-access
+ return descriptor._message.DescriptorPool(descriptor_db)
+
def __init__(self, descriptor_db=None):
"""Initializes a Pool of proto buffs.
@@ -264,6 +270,39 @@ class DescriptorPool(object):
self.FindFileContainingSymbol(full_name)
return self._enum_descriptors[full_name]
+ def FindFieldByName(self, full_name):
+ """Loads the named field descriptor from the pool.
+
+ Args:
+ full_name: The full name of the field descriptor to load.
+
+ Returns:
+ The field descriptor for the named field.
+ """
+ full_name = _NormalizeFullyQualifiedName(full_name)
+ message_name, _, field_name = full_name.rpartition('.')
+ message_descriptor = self.FindMessageTypeByName(message_name)
+ return message_descriptor.fields_by_name[field_name]
+
+ def FindExtensionByName(self, full_name):
+ """Loads the named extension descriptor from the pool.
+
+ Args:
+ full_name: The full name of the extension descriptor to load.
+
+ Returns:
+ A FieldDescriptor, describing the named extension.
+ """
+ full_name = _NormalizeFullyQualifiedName(full_name)
+ message_name, _, extension_name = full_name.rpartition('.')
+ try:
+ # Most extensions are nested inside a message.
+ scope = self.FindMessageTypeByName(message_name)
+ except KeyError:
+ # Some extensions are defined at file scope.
+ scope = self.FindFileContainingSymbol(full_name)
+ return scope.extensions_by_name[extension_name]
+
def _ConvertFileProtoToFileDescriptor(self, file_proto):
"""Creates a FileDescriptor from a proto or returns a cached copy.
@@ -282,6 +321,7 @@ class DescriptorPool(object):
direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
file_descriptor = descriptor.FileDescriptor(
+ pool=self,
name=file_proto.name,
package=file_proto.package,
syntax=file_proto.syntax,
@@ -598,10 +638,24 @@ class DescriptorPool(object):
field_desc.default_value = text_encoding.CUnescape(
field_proto.default_value)
else:
+ # All other types are of the "int" type.
field_desc.default_value = int(field_proto.default_value)
else:
field_desc.has_default_value = False
- field_desc.default_value = None
+ if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
+ field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
+ field_desc.default_value = 0.0
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
+ field_desc.default_value = u''
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
+ field_desc.default_value = False
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
+ field_desc.default_value = field_desc.enum_type.values[0].number
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
+ field_desc.default_value = b''
+ else:
+ # All other types are of the "int" type.
+ field_desc.default_value = 0
field_desc.type = field_proto.type
@@ -680,3 +734,16 @@ class DescriptorPool(object):
def _PrefixWithDot(name):
return name if name.startswith('.') else '.%s' % name
+
+
+if _USE_C_DESCRIPTORS:
+ # TODO(amauryfa): This pool could be constructed from Python code, when we
+ # support a flag like 'use_cpp_generated_pool=True'.
+ # pylint: disable=protected-access
+ _DEFAULT = descriptor._message.default_pool
+else:
+ _DEFAULT = DescriptorPool()
+
+
+def Default():
+ return _DEFAULT
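
Taken together, Default() and the new lookup methods behave roughly like this (a sketch; the protobuf_unittest names are illustrative and assume the corresponding generated module has been imported):

    from google.protobuf import descriptor_pool

    pool = descriptor_pool.Default()
    # FindFieldByName splits on the last dot: message name, then field name.
    field = pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32')
    # FindExtensionByName tries the enclosing message scope first, then
    # falls back to file scope for top-level extensions.
    ext = pool.FindExtensionByName('protobuf_unittest.optional_int32_extension')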
diff --git a/python/google/protobuf/internal/any_test.proto b/python/google/protobuf/internal/any_test.proto
new file mode 100644
index 00000000..cd641ca0
--- /dev/null
+++ b/python/google/protobuf/internal/any_test.proto
@@ -0,0 +1,42 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: jieluo@google.com (Jie Luo)
+
+syntax = "proto3";
+
+package google.protobuf.internal;
+
+import "google/protobuf/any.proto";
+
+message TestAny {
+ google.protobuf.Any value = 1;
+ int32 int_value = 2;
+}
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index 9c8275eb..97cdd848 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -464,6 +464,9 @@ class ScalarMap(MutableMapping):
return val
def __contains__(self, item):
+ # We check the key's type to match the strong-typing flavor of the API.
+ # Also this makes it easier to match the behavior of the C++ implementation.
+ self._key_checker.CheckValue(item)
return item in self._values
# We need to override this explicitly, because our defaultdict-like behavior
@@ -491,10 +494,20 @@ class ScalarMap(MutableMapping):
def __iter__(self):
return iter(self._values)
+ def __repr__(self):
+ return repr(self._values)
+
def MergeFrom(self, other):
self._values.update(other._values)
self._message_listener.Modified()
+ def InvalidateIterators(self):
+ # It appears that the only way to reliably invalidate iterators to
+ # self._values is to ensure that its size changes.
+ original = self._values
+ self._values = original.copy()
+ original[None] = None
+
# This is defined in the abstract base, but we can do it much more cheaply.
def clear(self):
self._values.clear()
@@ -576,12 +589,22 @@ class MessageMap(MutableMapping):
def __iter__(self):
return iter(self._values)
+ def __repr__(self):
+ return repr(self._values)
+
def MergeFrom(self, other):
for key in other:
self[key].MergeFrom(other[key])
# self._message_listener.Modified() not required here, because
# mutations to submessages already propagate.
+ def InvalidateIterators(self):
+ # It appears that the only way to reliably invalidate iterators to
+ # self._values is to ensure that its size changes.
+ original = self._values
+ self._values = original.copy()
+ original[None] = None
+
# This is defined in the abstract base, but we can do it much more cheaply.
def clear(self):
self._values.clear()
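
A sketch of the behavioral changes to map containers, assuming a generated message msg with a map<int32, int32> field named map_int32_int32 (hypothetical here):

    # Typed membership: __contains__ now checks the key type before lookup,
    # so a wrong-typed key raises TypeError instead of returning False.
    try:
        _ = 'abc' in msg.map_int32_int32
    except TypeError:
        pass

    # ClearField() invalidates live iterators via InvalidateIterators();
    # the underlying dict changes size, so iteration raises RuntimeError.
    it = iter(msg.map_int32_int32)
    msg.ClearField('map_int32_int32')
    try:
        next(it)
    except RuntimeError:
        pass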
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index da9a78db..f1d6bf99 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -40,6 +40,8 @@ try:
import unittest2 as unittest
except ImportError:
import unittest
+from google.protobuf import unittest_import_pb2
+from google.protobuf import unittest_import_public_pb2
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf.internal import api_implementation
@@ -51,13 +53,17 @@ from google.protobuf.internal import test_util
from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import descriptor_pool
+from google.protobuf import message_factory
from google.protobuf import symbol_database
class DescriptorPoolTest(unittest.TestCase):
+ def CreatePool(self):
+ return descriptor_pool.DescriptorPool()
+
def setUp(self):
- self.pool = descriptor_pool.DescriptorPool()
+ self.pool = self.CreatePool()
self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
factory_test1_pb2.DESCRIPTOR.serialized_pb)
self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
@@ -89,7 +95,7 @@ class DescriptorPoolTest(unittest.TestCase):
'google.protobuf.python.internal.Factory1Message')
self.assertIsInstance(file_desc1, descriptor.FileDescriptor)
self.assertEqual('google/protobuf/internal/factory_test1.proto',
- file_desc1.name)
+ file_desc1.name)
self.assertEqual('google.protobuf.python.internal', file_desc1.package)
self.assertIn('Factory1Message', file_desc1.message_types_by_name)
@@ -97,7 +103,7 @@ class DescriptorPoolTest(unittest.TestCase):
'google.protobuf.python.internal.Factory2Message')
self.assertIsInstance(file_desc2, descriptor.FileDescriptor)
self.assertEqual('google/protobuf/internal/factory_test2.proto',
- file_desc2.name)
+ file_desc2.name)
self.assertEqual('google.protobuf.python.internal', file_desc2.package)
self.assertIn('Factory2Message', file_desc2.message_types_by_name)
@@ -111,7 +117,7 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertIsInstance(msg1, descriptor.Descriptor)
self.assertEqual('Factory1Message', msg1.name)
self.assertEqual('google.protobuf.python.internal.Factory1Message',
- msg1.full_name)
+ msg1.full_name)
self.assertEqual(None, msg1.containing_type)
nested_msg1 = msg1.nested_types[0]
@@ -132,7 +138,7 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertIsInstance(msg2, descriptor.Descriptor)
self.assertEqual('Factory2Message', msg2.name)
self.assertEqual('google.protobuf.python.internal.Factory2Message',
- msg2.full_name)
+ msg2.full_name)
self.assertIsNone(msg2.containing_type)
nested_msg2 = msg2.nested_types[0]
@@ -223,6 +229,37 @@ class DescriptorPoolTest(unittest.TestCase):
with self.assertRaises(KeyError):
self.pool.FindEnumTypeByName('Does not exist')
+ def testFindFieldByName(self):
+ field = self.pool.FindFieldByName(
+ 'google.protobuf.python.internal.Factory1Message.list_value')
+ self.assertEqual(field.name, 'list_value')
+ self.assertEqual(field.label, field.LABEL_REPEATED)
+ with self.assertRaises(KeyError):
+ self.pool.FindFieldByName('Does not exist')
+
+ def testFindExtensionByName(self):
+ # An extension defined in a message.
+ extension = self.pool.FindExtensionByName(
+ 'google.protobuf.python.internal.Factory2Message.one_more_field')
+ self.assertEqual(extension.name, 'one_more_field')
+ # An extension defined at file scope.
+ extension = self.pool.FindExtensionByName(
+ 'google.protobuf.python.internal.another_field')
+ self.assertEqual(extension.name, 'another_field')
+ self.assertEqual(extension.number, 1002)
+ with self.assertRaises(KeyError):
+ self.pool.FindFieldByName('Does not exist')
+
+ def testExtensionsAreNotFields(self):
+ with self.assertRaises(KeyError):
+ self.pool.FindFieldByName('google.protobuf.python.internal.another_field')
+ with self.assertRaises(KeyError):
+ self.pool.FindFieldByName(
+ 'google.protobuf.python.internal.Factory2Message.one_more_field')
+ with self.assertRaises(KeyError):
+ self.pool.FindExtensionByName(
+ 'google.protobuf.python.internal.Factory1Message.list_value')
+
def testUserDefinedDB(self):
db = descriptor_database.DescriptorDatabase()
self.pool = descriptor_pool.DescriptorPool(db)
@@ -231,8 +268,7 @@ class DescriptorPoolTest(unittest.TestCase):
self.testFindMessageTypeByName()
def testAddSerializedFile(self):
- db = descriptor_database.DescriptorDatabase()
- self.pool = descriptor_pool.DescriptorPool(db)
+ self.pool = descriptor_pool.DescriptorPool()
self.pool.AddSerializedFile(self.factory_test1_fd.SerializeToString())
self.pool.AddSerializedFile(self.factory_test2_fd.SerializeToString())
self.testFindMessageTypeByName()
@@ -274,6 +310,56 @@ class DescriptorPoolTest(unittest.TestCase):
'google/protobuf/internal/descriptor_pool_test1.proto')
_CheckDefaultValue(file_descriptor)
+ def testDefaultValueForCustomMessages(self):
+ """Check the value returned by non-existent fields."""
+ def _CheckValueAndType(value, expected_value, expected_type):
+ self.assertEqual(value, expected_value)
+ self.assertIsInstance(value, expected_type)
+
+ def _CheckDefaultValues(msg):
+ try:
+ int64 = long
+ except NameError: # Python3
+ int64 = int
+ try:
+ unicode_type = unicode
+ except NameError: # Python3
+ unicode_type = str
+ _CheckValueAndType(msg.optional_int32, 0, int)
+ _CheckValueAndType(msg.optional_uint64, 0, (int64, int))
+ _CheckValueAndType(msg.optional_float, 0, (float, int))
+ _CheckValueAndType(msg.optional_double, 0, (float, int))
+ _CheckValueAndType(msg.optional_bool, False, bool)
+ _CheckValueAndType(msg.optional_string, u'', unicode_type)
+ _CheckValueAndType(msg.optional_bytes, b'', bytes)
+ _CheckValueAndType(msg.optional_nested_enum, msg.FOO, int)
+ # First for the generated message
+ _CheckDefaultValues(unittest_pb2.TestAllTypes())
+ # Then for a message built from the DescriptorPool.
+ pool = descriptor_pool.DescriptorPool()
+ pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_import_public_pb2.DESCRIPTOR.serialized_pb))
+ pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_import_pb2.DESCRIPTOR.serialized_pb))
+ pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_pb2.DESCRIPTOR.serialized_pb))
+ message_class = message_factory.MessageFactory(pool).GetPrototype(
+ pool.FindMessageTypeByName(
+ unittest_pb2.TestAllTypes.DESCRIPTOR.full_name))
+ _CheckDefaultValues(message_class())
+
+
+@unittest.skipIf(api_implementation.Type() != 'cpp',
+ 'explicit tests of the C++ implementation')
+class CppDescriptorPoolTest(DescriptorPoolTest):
+ # TODO(amauryfa): remove when descriptor_pool.DescriptorPool() creates true
+ # C++ descriptor pool object for C++ implementation.
+
+ def CreatePool(self):
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.pyext import _message
+ return _message.DescriptorPool()
+
class ProtoFile(object):
@@ -468,6 +554,8 @@ class AddDescriptorTest(unittest.TestCase):
pool.FindFileContainingSymbol(
prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').name)
+ @unittest.skipIf(api_implementation.Type() == 'cpp',
+ 'With the cpp implementation, Add() must be called first')
def testMessage(self):
self._TestMessage('')
self._TestMessage('.')
@@ -502,10 +590,14 @@ class AddDescriptorTest(unittest.TestCase):
pool.FindFileContainingSymbol(
prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').name)
+ @unittest.skipIf(api_implementation.Type() == 'cpp',
+ 'With the cpp implementation, Add() must be called first')
def testEnum(self):
self._TestEnum('')
self._TestEnum('.')
+ @unittest.skipIf(api_implementation.Type() == 'cpp',
+ 'With the cpp implementation, Add() must be called first')
def testFile(self):
pool = descriptor_pool.DescriptorPool()
pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR)
@@ -520,6 +612,76 @@ class AddDescriptorTest(unittest.TestCase):
pool.FindFileContainingSymbol(
'protobuf_unittest.TestAllTypes')
+ def _GetDescriptorPoolClass(self):
+ # Test with both implementations of descriptor pools.
+ if api_implementation.Type() == 'cpp':
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.pyext import _message
+ return _message.DescriptorPool
+ else:
+ return descriptor_pool.DescriptorPool
+
+ def testEmptyDescriptorPool(self):
+ # Check that an empty DescriptorPool() contains no message.
+ pool = self._GetDescriptorPoolClass()()
+ proto_file_name = descriptor_pb2.DESCRIPTOR.name
+ self.assertRaises(KeyError, pool.FindFileByName, proto_file_name)
+ # Add the above file to the pool
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ descriptor_pb2.DESCRIPTOR.CopyToProto(file_descriptor)
+ pool.Add(file_descriptor)
+ # Now it exists.
+ self.assertTrue(pool.FindFileByName(proto_file_name))
+
+ def testCustomDescriptorPool(self):
+ # Create a new pool, and add a file descriptor.
+ pool = self._GetDescriptorPoolClass()()
+ file_desc = descriptor_pb2.FileDescriptorProto(
+ name='some/file.proto', package='package')
+ file_desc.message_type.add(name='Message')
+ pool.Add(file_desc)
+ self.assertEqual(pool.FindFileByName('some/file.proto').name,
+ 'some/file.proto')
+ self.assertEqual(pool.FindMessageTypeByName('package.Message').name,
+ 'Message')
+
+
+@unittest.skipIf(
+ api_implementation.Type() != 'cpp',
+ 'default_pool is only supported by the C++ implementation')
+class DefaultPoolTest(unittest.TestCase):
+
+ def testFindMethods(self):
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.pyext import _message
+ pool = _message.default_pool
+ self.assertIs(
+ pool.FindFileByName('google/protobuf/unittest.proto'),
+ unittest_pb2.DESCRIPTOR)
+ self.assertIs(
+ pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR)
+ self.assertIs(
+ pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32'])
+ self.assertIs(
+ pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'),
+ unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension'])
+ self.assertIs(
+ pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'),
+ unittest_pb2.ForeignEnum.DESCRIPTOR)
+ self.assertIs(
+ pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field'])
+
+ def testAddFileDescriptor(self):
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.pyext import _message
+ pool = _message.default_pool
+ file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto')
+ pool.Add(file_desc)
+ pool.AddSerializedFile(file_desc.SerializeToString())
+
TEST1_FILE = ProtoFile(
'google/protobuf/internal/descriptor_pool_test1.proto',
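
The pattern these new tests rely on — loading serialized file protos into a fresh pool, dependencies first, then instantiating classes through a factory — reduces to the following sketch (the unittest_*_pb2 modules stand in for any generated code):

    from google.protobuf import descriptor_pb2
    from google.protobuf import descriptor_pool
    from google.protobuf import message_factory

    pool = descriptor_pool.DescriptorPool()
    for module in (unittest_import_public_pb2, unittest_import_pb2,
                   unittest_pb2):
        pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
            module.DESCRIPTOR.serialized_pb))
    factory = message_factory.MessageFactory(pool)
    cls = factory.GetPrototype(
        pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'))
    msg = cls()  # unset fields report typed defaults: 0, 0.0, False, u'', b''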
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index 99afee63..fee09a56 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -47,6 +47,7 @@ from google.protobuf import descriptor_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import test_util
from google.protobuf import descriptor
+from google.protobuf import descriptor_pool
from google.protobuf import symbol_database
from google.protobuf import text_format
@@ -75,9 +76,9 @@ class DescriptorTest(unittest.TestCase):
enum_proto.value.add(name='FOREIGN_BAR', number=5)
enum_proto.value.add(name='FOREIGN_BAZ', number=6)
- descriptor_pool = symbol_database.Default().pool
- descriptor_pool.Add(file_proto)
- self.my_file = descriptor_pool.FindFileByName(file_proto.name)
+ self.pool = self.GetDescriptorPool()
+ self.pool.Add(file_proto)
+ self.my_file = self.pool.FindFileByName(file_proto.name)
self.my_message = self.my_file.message_types_by_name[message_proto.name]
self.my_enum = self.my_message.enum_types_by_name[enum_proto.name]
@@ -97,6 +98,9 @@ class DescriptorTest(unittest.TestCase):
self.my_method
])
+ def GetDescriptorPool(self):
+ return symbol_database.Default().pool
+
def testEnumValueName(self):
self.assertEqual(self.my_message.EnumValueName('ForeignEnum', 4),
'FOREIGN_FOO')
@@ -393,6 +397,9 @@ class DescriptorTest(unittest.TestCase):
def testFileDescriptor(self):
self.assertEqual(self.my_file.name, 'some/filename/some.proto')
self.assertEqual(self.my_file.package, 'protobuf_unittest')
+ self.assertEqual(self.my_file.pool, self.pool)
+ # Generated modules also belong to the default pool.
+ self.assertEqual(unittest_pb2.DESCRIPTOR.pool, descriptor_pool.Default())
@unittest.skipIf(
api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
@@ -407,6 +414,13 @@ class DescriptorTest(unittest.TestCase):
message_descriptor.fields.append(None)
+class NewDescriptorTest(DescriptorTest):
+ """Redo the same tests as above, but with a separate DescriptorPool."""
+
+ def GetDescriptorPool(self):
+ return descriptor_pool.DescriptorPool()
+
+
class GeneratedDescriptorTest(unittest.TestCase):
"""Tests for the properties of descriptors in generated code."""
diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py
index 69197865..be3ad11a 100644
--- a/python/google/protobuf/internal/json_format_test.py
+++ b/python/google/protobuf/internal/json_format_test.py
@@ -42,6 +42,7 @@ try:
import unittest2 as unittest
except ImportError:
import unittest
+from google.protobuf.internal import well_known_types
from google.protobuf import json_format
from google.protobuf.util import json_format_proto3_pb2
@@ -269,15 +270,15 @@ class JsonFormatTest(JsonFormatBase):
'}'))
parsed_message = json_format_proto3_pb2.TestTimestamp()
self.CheckParseBack(message, parsed_message)
- text = (r'{"value": "1972-01-01T01:00:00.01+08:00",'
+ text = (r'{"value": "1970-01-01T00:00:00.01+08:00",'
r'"repeatedValue":['
- r' "1972-01-01T01:00:00.01+08:30",'
- r' "1972-01-01T01:00:00.01-01:23"]}')
+ r' "1970-01-01T00:00:00.01+08:30",'
+ r' "1970-01-01T00:00:00.01-01:23"]}')
json_format.Parse(text, parsed_message)
- self.assertEqual(parsed_message.value.seconds, 63104400)
+ self.assertEqual(parsed_message.value.seconds, -8 * 3600)
self.assertEqual(parsed_message.value.nanos, 10000000)
- self.assertEqual(parsed_message.repeated_value[0].seconds, 63106200)
- self.assertEqual(parsed_message.repeated_value[1].seconds, 63070620)
+ self.assertEqual(parsed_message.repeated_value[0].seconds, -8.5 * 3600)
+ self.assertEqual(parsed_message.repeated_value[1].seconds, 3600 + 23 * 60)
def testDurationMessage(self):
message = json_format_proto3_pb2.TestDuration()
@@ -389,7 +390,7 @@ class JsonFormatTest(JsonFormatBase):
def testParseEmptyText(self):
self.CheckError('',
- r'Failed to load JSON: (Expecting value)|(No JSON)')
+ r'Failed to load JSON: (Expecting value)|(No JSON).')
def testParseBadEnumValue(self):
self.CheckError(
@@ -414,7 +415,7 @@ class JsonFormatTest(JsonFormatBase):
if sys.version_info < (2, 7):
return
self.CheckError('{"int32Value": 1,\n"int32Value":2}',
- 'Failed to load JSON: duplicate key int32Value')
+ 'Failed to load JSON: duplicate key int32Value.')
def testInvalidBoolValue(self):
self.CheckError('{"boolValue": 1}',
@@ -431,39 +432,43 @@ class JsonFormatTest(JsonFormatBase):
json_format.Parse, text, message)
self.CheckError('{"int32Value": 012345}',
(r'Failed to load JSON: Expecting \'?,\'? delimiter: '
- r'line 1'))
+ r'line 1.'))
self.CheckError('{"int32Value": 1.0}',
'Failed to parse int32Value field: '
- 'Couldn\'t parse integer: 1.0')
+ 'Couldn\'t parse integer: 1.0.')
self.CheckError('{"int32Value": " 1 "}',
'Failed to parse int32Value field: '
- 'Couldn\'t parse integer: " 1 "')
+ 'Couldn\'t parse integer: " 1 ".')
+ self.CheckError('{"int32Value": "1 "}',
+ 'Failed to parse int32Value field: '
+ 'Couldn\'t parse integer: "1 ".')
self.CheckError('{"int32Value": 12345678901234567890}',
'Failed to parse int32Value field: Value out of range: '
- '12345678901234567890')
+ '12345678901234567890.')
self.CheckError('{"int32Value": 1e5}',
'Failed to parse int32Value field: '
- 'Couldn\'t parse integer: 100000.0')
+ 'Couldn\'t parse integer: 100000.0.')
self.CheckError('{"uint32Value": -1}',
- 'Failed to parse uint32Value field: Value out of range: -1')
+ 'Failed to parse uint32Value field: '
+ 'Value out of range: -1.')
def testInvalidFloatValue(self):
self.CheckError('{"floatValue": "nan"}',
'Failed to parse floatValue field: Couldn\'t '
- 'parse float "nan", use "NaN" instead')
+ 'parse float "nan", use "NaN" instead.')
def testInvalidBytesValue(self):
self.CheckError('{"bytesValue": "AQI"}',
- 'Failed to parse bytesValue field: Incorrect padding')
+ 'Failed to parse bytesValue field: Incorrect padding.')
self.CheckError('{"bytesValue": "AQI*"}',
- 'Failed to parse bytesValue field: Incorrect padding')
+ 'Failed to parse bytesValue field: Incorrect padding.')
def testInvalidMap(self):
message = json_format_proto3_pb2.TestMap()
text = '{"int32Map": {"null": 2, "2": 3}}'
self.assertRaisesRegexp(
json_format.ParseError,
- 'Failed to parse int32Map field: Couldn\'t parse integer: "null"',
+ 'Failed to parse int32Map field: invalid literal',
json_format.Parse, text, message)
text = '{"int32Map": {1: 2, "2": 3}}'
self.assertRaisesRegexp(
@@ -474,7 +479,7 @@ class JsonFormatTest(JsonFormatBase):
text = '{"boolMap": {"null": 1}}'
self.assertRaisesRegexp(
json_format.ParseError,
- 'Failed to parse boolMap field: Expect "true" or "false", not null.',
+ 'Failed to parse boolMap field: Expected "true" or "false", not null.',
json_format.Parse, text, message)
if sys.version_info < (2, 7):
return
@@ -490,30 +495,29 @@ class JsonFormatTest(JsonFormatBase):
self.assertRaisesRegexp(
json_format.ParseError,
'time data \'10000-01-01T00:00:00\' does not match'
- ' format \'%Y-%m-%dT%H:%M:%S\'',
+ ' format \'%Y-%m-%dT%H:%M:%S\'.',
json_format.Parse, text, message)
text = '{"value": "1970-01-01T00:00:00.0123456789012Z"}'
self.assertRaisesRegexp(
- json_format.ParseError,
- 'Failed to parse value field: Failed to parse Timestamp: '
+ well_known_types.ParseError,
'nanos 0123456789012 more than 9 fractional digits.',
json_format.Parse, text, message)
text = '{"value": "1972-01-01T01:00:00.01+08"}'
self.assertRaisesRegexp(
- json_format.ParseError,
- (r'Failed to parse value field: Invalid timezone offset value: \+08'),
+ well_known_types.ParseError,
+ (r'Invalid timezone offset value: \+08.'),
json_format.Parse, text, message)
# Time smaller than minimum time.
text = '{"value": "0000-01-01T00:00:00Z"}'
self.assertRaisesRegexp(
json_format.ParseError,
- 'Failed to parse value field: year is out of range',
+ 'Failed to parse value field: year is out of range.',
json_format.Parse, text, message)
# Time bigger than maximum time.
message.value.seconds = 253402300800
self.assertRaisesRegexp(
- json_format.SerializeToJsonError,
- 'Failed to serialize value field: year is out of range',
+ OverflowError,
+ 'date value out of range',
json_format.MessageToJson, message)
def testInvalidOneof(self):
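
The corrected timestamp expectations follow from UTC offset arithmetic: a wall-clock time carrying an offset is converted to UTC by subtracting the offset, so '1970-01-01T00:00:00.01+08:00' lands 8 hours before the epoch. As a worked check:

    # UTC seconds = local seconds - offset; nanos carry the fraction.
    seconds = 0 - 8 * 3600       # -28800, as asserted above
    nanos = int(0.01 * 1e9)      # 10000000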
diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py
index d760b898..2fbe5ea7 100644
--- a/python/google/protobuf/internal/message_factory_test.py
+++ b/python/google/protobuf/internal/message_factory_test.py
@@ -45,6 +45,7 @@ from google.protobuf import descriptor_database
from google.protobuf import descriptor_pool
from google.protobuf import message_factory
+
class MessageFactoryTest(unittest.TestCase):
def setUp(self):
@@ -104,8 +105,8 @@ class MessageFactoryTest(unittest.TestCase):
def testGetMessages(self):
# performed twice because multiple calls with the same input must be allowed
for _ in range(2):
- messages = message_factory.GetMessages([self.factory_test2_fd,
- self.factory_test1_fd])
+ messages = message_factory.GetMessages([self.factory_test1_fd,
+ self.factory_test2_fd])
self.assertTrue(
set(['google.protobuf.python.internal.Factory2Message',
'google.protobuf.python.internal.Factory1Message'],
@@ -116,7 +117,7 @@ class MessageFactoryTest(unittest.TestCase):
set(['google.protobuf.python.internal.Factory2Message.one_more_field',
'google.protobuf.python.internal.another_field'],
).issubset(
- set(messages['google.protobuf.python.internal.Factory1Message']
+ set(messages['google.protobuf.python.internal.Factory1Message']
._extensions_by_name.keys())))
factory_msg1 = messages['google.protobuf.python.internal.Factory1Message']
msg1 = messages['google.protobuf.python.internal.Factory1Message']()
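
The reordered arguments suggest GetMessages() expects dependencies to precede the files that import them (factory_test1 is imported by factory_test2); a hedged sketch:

    # Dependency first, dependent second.
    messages = message_factory.GetMessages(
        [factory_test1_fd, factory_test2_fd])
    msg = messages['google.protobuf.python.internal.Factory2Message']()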
diff --git a/python/google/protobuf/internal/message_set_extensions.proto b/python/google/protobuf/internal/message_set_extensions.proto
index 702c8d07..14e5f193 100644
--- a/python/google/protobuf/internal/message_set_extensions.proto
+++ b/python/google/protobuf/internal/message_set_extensions.proto
@@ -54,6 +54,14 @@ message TestMessageSetExtension2 {
optional string str = 25;
}
+message TestMessageSetExtension3 {
+ optional string text = 35;
+}
+
+extend TestMessageSet {
+ optional TestMessageSetExtension3 message_set_extension3 = 98418655;
+}
+
// This message was used to generate
// //net/proto2/python/internal/testdata/message_set_message, but is commented
// out since it must not actually exist in code, to simulate an "unknown"
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 13c3caa6..d03f2d25 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -60,6 +60,7 @@ from google.protobuf.internal import _parameterized
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
+from google.protobuf.internal import any_test_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import packed_field_test_pb2
from google.protobuf.internal import test_util
@@ -1279,12 +1280,13 @@ class Proto3Test(unittest.TestCase):
self.assertIsInstance(msg.map_string_string['abc'], six.text_type)
- # Accessing an unset key still throws TypeError of the type of the key
+ # Accessing an unset key still throws TypeError if the type of the key
# is incorrect.
with self.assertRaises(TypeError):
msg.map_string_string[123]
- self.assertFalse(123 in msg.map_string_string)
+ with self.assertRaises(TypeError):
+ 123 in msg.map_string_string
def testMapGet(self):
# Need to test that get() properly returns the default, even though the dict
@@ -1591,31 +1593,49 @@ class Proto3Test(unittest.TestCase):
# For the C++ implementation this tests the correctness of
# ScalarMapContainer::Release()
msg = map_unittest_pb2.TestMap()
- map = msg.map_int32_int32
+ int32_map = msg.map_int32_int32
- map[2] = 4
- map[3] = 6
- map[4] = 8
+ int32_map[2] = 4
+ int32_map[3] = 6
+ int32_map[4] = 8
msg.ClearField('map_int32_int32')
+ self.assertEqual(b'', msg.SerializeToString())
matching_dict = {2: 4, 3: 6, 4: 8}
- self.assertMapIterEquals(map.items(), matching_dict)
+ self.assertMapIterEquals(int32_map.items(), matching_dict)
- def testMapIterValidAfterFieldCleared(self):
- # Map iterator needs to work even if field is cleared.
+ def testMessageMapValidAfterFieldCleared(self):
+ # Map needs to work even if field is cleared.
# For the C++ implementation this tests the correctness of
# ScalarMapContainer::Release()
msg = map_unittest_pb2.TestMap()
+ int32_foreign_message = msg.map_int32_foreign_message
- msg.map_int32_int32[2] = 4
- msg.map_int32_int32[3] = 6
- msg.map_int32_int32[4] = 8
+ int32_foreign_message[2].c = 5
- it = msg.map_int32_int32.items()
+ msg.ClearField('map_int32_foreign_message')
+ self.assertEqual(b'', msg.SerializeToString())
+ self.assertTrue(2 in int32_foreign_message.keys())
+
+ def testMapIterInvalidatedByClearField(self):
+ # Map iterator is invalidated when field is cleared.
+ # But it still must not crash the interpreter.
+ # For the C++ implementation this tests the correctness of
+ # ScalarMapContainer::Release()
+ msg = map_unittest_pb2.TestMap()
+
+ it = iter(msg.map_int32_int32)
msg.ClearField('map_int32_int32')
- matching_dict = {2: 4, 3: 6, 4: 8}
- self.assertMapIterEquals(it, matching_dict)
+ with self.assertRaises(RuntimeError):
+ for _ in it:
+ pass
+
+ it = iter(msg.map_int32_foreign_message)
+ msg.ClearField('map_int32_foreign_message')
+ with self.assertRaises(RuntimeError):
+ for _ in it:
+ pass
def testMapDelete(self):
msg = map_unittest_pb2.TestMap()
@@ -1646,6 +1666,37 @@ class Proto3Test(unittest.TestCase):
msg.map_string_foreign_message['foo'].c = 5
self.assertEqual(0, len(msg.FindInitializationErrors()))
+ def testAnyMessage(self):
+ # Creates and sets message.
+ msg = any_test_pb2.TestAny()
+ msg_descriptor = msg.DESCRIPTOR
+ all_types = unittest_pb2.TestAllTypes()
+ all_descriptor = all_types.DESCRIPTOR
+ all_types.repeated_string.append(u'\u00fc\ua71f')
+ # Packs to Any.
+ msg.value.Pack(all_types)
+ self.assertEqual(msg.value.type_url,
+ 'type.googleapis.com/%s' % all_descriptor.full_name)
+ self.assertEqual(msg.value.value,
+ all_types.SerializeToString())
+ # Tests Is() method.
+ self.assertTrue(msg.value.Is(all_descriptor))
+ self.assertFalse(msg.value.Is(msg_descriptor))
+ # Unpacks Any.
+ unpacked_message = unittest_pb2.TestAllTypes()
+ self.assertTrue(msg.value.Unpack(unpacked_message))
+ self.assertEqual(all_types, unpacked_message)
+ # Unpacks to different type.
+ self.assertFalse(msg.value.Unpack(msg))
+ # Only Any messages have a Pack method.
+ try:
+ msg.Pack(all_types)
+ except AttributeError:
+ pass
+ else:
+ raise AttributeError('%s should not have Pack method.' %
+ msg_descriptor.full_name)
+
class ValidTypeNamesTest(unittest.TestCase):
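
The Any workflow exercised above reduces to Pack/Unpack/Is on a google.protobuf.Any field (a sketch; TestAny is the test message added earlier in this commit, and any message type works):

    msg = any_test_pb2.TestAny()
    inner = unittest_pb2.TestAllTypes(optional_int32=42)
    msg.value.Pack(inner)              # sets type_url and serialized value
    assert msg.value.Is(inner.DESCRIPTOR)
    out = unittest_pb2.TestAllTypes()
    assert msg.value.Unpack(out)       # True: type matches
    assert out.optional_int32 == 42
    assert not msg.value.Unpack(msg)   # False: TestAny is not the packed type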
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 2b87f704..87f60666 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -65,6 +65,7 @@ from google.protobuf.internal import encoder
from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import message_listener as message_listener_mod
from google.protobuf.internal import type_checkers
+from google.protobuf.internal import well_known_types
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
@@ -72,6 +73,7 @@ from google.protobuf import symbol_database
from google.protobuf import text_format
_FieldDescriptor = descriptor_mod.FieldDescriptor
+_AnyFullTypeName = 'google.protobuf.Any'
class GeneratedProtocolMessageType(type):
@@ -127,6 +129,8 @@ class GeneratedProtocolMessageType(type):
Newly-allocated class.
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
+ if descriptor.full_name in well_known_types.WKTBASES:
+ bases += (well_known_types.WKTBASES[descriptor.full_name],)
_AddClassAttributesForNestedExtensions(descriptor, dictionary)
_AddSlots(descriptor, dictionary)
@@ -261,7 +265,6 @@ def _IsMessageSetExtension(field):
field.containing_type.has_options and
field.containing_type.GetOptions().message_set_wire_format and
field.type == _FieldDescriptor.TYPE_MESSAGE and
- field.message_type == field.extension_scope and
field.label == _FieldDescriptor.LABEL_OPTIONAL)
@@ -543,7 +546,8 @@ def _GetFieldByName(message_descriptor, field_name):
try:
return message_descriptor.fields_by_name[field_name]
except KeyError:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
+ raise ValueError('Protocol message %s has no "%s" field.' %
+ (message_descriptor.name, field_name))
def _AddPropertiesForFields(descriptor, cls):
@@ -848,9 +852,15 @@ def _AddClearFieldMethod(message_descriptor, cls):
else:
return
except KeyError:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
+ raise ValueError('Protocol message %s() has no "%s" field.' %
+ (message_descriptor.name, field_name))
if field in self._fields:
+ # To match the C++ implementation, we need to invalidate iterators
+ # for map fields when ClearField() happens.
+ if hasattr(self._fields[field], 'InvalidateIterators'):
+ self._fields[field].InvalidateIterators()
+
# Note: If the field is a sub-message, its listener will still point
# at us. That's fine, because the worst that can happen is that it
# will call _Modified() and invalidate our byte size. Big deal.
@@ -904,7 +914,19 @@ def _AddHasExtensionMethod(cls):
return extension_handle in self._fields
cls.HasExtension = HasExtension
-def _UnpackAny(msg):
+def _InternalUnpackAny(msg):
+ """Unpacks Any message and returns the unpacked message.
+
+ This internal method is different from the public Any Unpack method, which
+ takes the target message as an argument. _InternalUnpackAny does not know
+ the target message type and needs to find it in the descriptor pool.
+
+ Args:
+ msg: An Any message to be unpacked.
+
+ Returns:
+ The unpacked message.
+ """
type_url = msg.type_url
db = symbol_database.Default()
@@ -935,9 +957,9 @@ def _AddEqualsMethod(message_descriptor, cls):
if self is other:
return True
- if self.DESCRIPTOR.full_name == "google.protobuf.Any":
- any_a = _UnpackAny(self)
- any_b = _UnpackAny(other)
+ if self.DESCRIPTOR.full_name == _AnyFullTypeName:
+ any_a = _InternalUnpackAny(self)
+ any_b = _InternalUnpackAny(other)
if any_a and any_b:
return any_a == any_b
@@ -962,6 +984,13 @@ def _AddStrMethod(message_descriptor, cls):
cls.__str__ = __str__
+def _AddReprMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ def __repr__(self):
+ return text_format.MessageToString(self)
+ cls.__repr__ = __repr__
+
+
def _AddUnicodeMethod(unused_message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
@@ -1270,6 +1299,7 @@ def _AddMessageMethods(message_descriptor, cls):
_AddClearMethod(message_descriptor, cls)
_AddEqualsMethod(message_descriptor, cls)
_AddStrMethod(message_descriptor, cls)
+ _AddReprMethod(message_descriptor, cls)
_AddUnicodeMethod(message_descriptor, cls)
_AddSetListenerMethod(cls)
_AddByteSizeMethod(message_descriptor, cls)
@@ -1280,6 +1310,7 @@ def _AddMessageMethods(message_descriptor, cls):
_AddMergeFromMethod(cls)
_AddWhichOneofMethod(message_descriptor, cls)
+
def _AddPrivateHelperMethods(message_descriptor, cls):
"""Adds implementation of private helper methods to cls."""
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 8621d61e..752f2f5d 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -2394,8 +2394,10 @@ class SerializationTest(unittest.TestCase):
extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2
extension1 = extension_message1.message_set_extension
extension2 = extension_message2.message_set_extension
+ extension3 = message_set_extensions_pb2.message_set_extension3
proto.Extensions[extension1].i = 123
proto.Extensions[extension2].str = 'foo'
+ proto.Extensions[extension3].text = 'bar'
# Serialize using the MessageSet wire format (this is specified in the
# .proto file).
@@ -2407,7 +2409,7 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(
len(serialized),
raw.MergeFromString(serialized))
- self.assertEqual(2, len(raw.item))
+ self.assertEqual(3, len(raw.item))
message1 = message_set_extensions_pb2.TestMessageSetExtension1()
self.assertEqual(
@@ -2421,6 +2423,12 @@ class SerializationTest(unittest.TestCase):
message2.MergeFromString(raw.item[1].message))
self.assertEqual('foo', message2.str)
+ message3 = message_set_extensions_pb2.TestMessageSetExtension3()
+ self.assertEqual(
+ len(raw.item[2].message),
+ message3.MergeFromString(raw.item[2].message))
+ self.assertEqual('bar', message3.text)
+
# Deserialize using the MessageSet wire format.
proto2 = message_set_extensions_pb2.TestMessageSet()
self.assertEqual(
@@ -2428,6 +2436,7 @@ class SerializationTest(unittest.TestCase):
proto2.MergeFromString(serialized))
self.assertEqual(123, proto2.Extensions[extension1].i)
self.assertEqual('foo', proto2.Extensions[extension2].str)
+ self.assertEqual('bar', proto2.Extensions[extension3].text)
# Check byte size.
self.assertEqual(proto2.ByteSize(), len(serialized))
@@ -2757,9 +2766,10 @@ class SerializationTest(unittest.TestCase):
def testInitArgsUnknownFieldName(self):
def InitalizeEmptyMessageWithExtraKeywordArg():
unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
- self._CheckRaises(ValueError,
- InitalizeEmptyMessageWithExtraKeywordArg,
- 'Protocol message has no "unknown" field.')
+ self._CheckRaises(
+ ValueError,
+ InitalizeEmptyMessageWithExtraKeywordArg,
+ 'Protocol message TestEmptyMessage has no "unknown" field.')
def testInitRequiredKwargs(self):
proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index cca0ee63..0e14556c 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -51,8 +51,22 @@ from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import test_util
+from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf import text_format
+
+# Low-level nuts-n-bolts tests.
+class SimpleTextFormatTests(unittest.TestCase):
+
+ # The members of _QUOTES are formatted into a regexp template that
+ # expects single characters. Therefore it's an error (in addition to being
+ # nonsensical in the first place) to try to specify a "quote mark" that is
+ # more than one character.
+ def testQuoteMarksAreSingleChars(self):
+ for quote in text_format._QUOTES:
+ self.assertEqual(1, len(quote))
+
+
# Base class with some common functionality.
class TextFormatBase(unittest.TestCase):
@@ -287,6 +301,19 @@ class TextFormatTest(TextFormatBase):
self.assertEqual(u'one', message.repeated_string[0])
self.assertEqual(u'two', message.repeated_string[1])
+ def testParseRepeatedScalarShortFormat(self, message_module):
+ message = message_module.TestAllTypes()
+ text = ('repeated_int64: [100, 200];\n'
+ 'repeated_int64: 300,\n'
+ 'repeated_string: ["one", "two"];\n')
+ text_format.Parse(text, message)
+
+ self.assertEqual(100, message.repeated_int64[0])
+ self.assertEqual(200, message.repeated_int64[1])
+ self.assertEqual(300, message.repeated_int64[2])
+ self.assertEqual(u'one', message.repeated_string[0])
+ self.assertEqual(u'two', message.repeated_string[1])
+
def testParseEmptyText(self, message_module):
message = message_module.TestAllTypes()
text = ''
@@ -301,7 +328,7 @@ class TextFormatTest(TextFormatBase):
def testParseSingleWord(self, message_module):
message = message_module.TestAllTypes()
text = 'foo'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
(r'1:1 : Message type "\w+.TestAllTypes" has no field named '
r'"foo".'),
@@ -310,7 +337,7 @@ class TextFormatTest(TextFormatBase):
def testParseUnknownField(self, message_module):
message = message_module.TestAllTypes()
text = 'unknown_field: 8\n'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
(r'1:1 : Message type "\w+.TestAllTypes" has no field named '
r'"unknown_field".'),
@@ -319,7 +346,7 @@ class TextFormatTest(TextFormatBase):
def testParseBadEnumValue(self, message_module):
message = message_module.TestAllTypes()
text = 'optional_nested_enum: BARR'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
(r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
r'has no value named BARR.'),
@@ -327,7 +354,7 @@ class TextFormatTest(TextFormatBase):
message = message_module.TestAllTypes()
text = 'optional_nested_enum: 100'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
(r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
r'has no value with number 100.'),
@@ -336,7 +363,7 @@ class TextFormatTest(TextFormatBase):
def testParseBadIntValue(self, message_module):
message = message_module.TestAllTypes()
text = 'optional_int32: bork'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
('1:17 : Couldn\'t parse integer: bork'),
text_format.Parse, text, message)
@@ -553,6 +580,15 @@ class Proto2Tests(TextFormatBase):
' }\n'
'}\n')
+ message = message_set_extensions_pb2.TestMessageSet()
+ ext = message_set_extensions_pb2.message_set_extension3
+ message.Extensions[ext].text = 'bar'
+ self.CompareToGoldenText(
+ text_format.MessageToString(message),
+ '[google.protobuf.internal.TestMessageSetExtension3] {\n'
+ ' text: \"bar\"\n'
+ '}\n')
+
def testPrintMessageSetAsOneLine(self):
message = unittest_mset_pb2.TestMessageSetContainer()
ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
@@ -627,15 +663,125 @@ class Proto2Tests(TextFormatBase):
text_format.Parse(ascii_text, parsed_message)
self.assertEqual(message, parsed_message)
+ def testParseAllowedUnknownExtension(self):
+ # Skip over unknown extension correctly.
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ text = ('message_set {\n'
+ ' [unknown_extension] {\n'
+ ' i: 23\n'
+ ' [nested_unknown_ext]: {\n'
+ ' i: 23\n'
+ ' test: "test_string"\n'
+ ' floaty_float: -0.315\n'
+ ' num: -inf\n'
+ ' multiline_str: "abc"\n'
+ ' "def"\n'
+ ' "xyz."\n'
+ ' [nested_unknown_ext]: <\n'
+ ' i: 23\n'
+ ' i: 24\n'
+ ' pointfloat: .3\n'
+ ' test: "test_string"\n'
+ ' floaty_float: -0.315\n'
+ ' num: -inf\n'
+ ' long_string: "test" "test2" \n'
+ ' >\n'
+ ' }\n'
+ ' }\n'
+ ' [unknown_extension]: 5\n'
+ '}\n')
+ text_format.Parse(text, message, allow_unknown_extension=True)
+ golden = 'message_set {\n}\n'
+ self.CompareToGoldenText(text_format.MessageToString(message), golden)
+
+ # Catch parse errors in unknown extension.
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' [unknown_extension] {\n'
+ ' i:\n' # Missing value.
+ ' }\n'
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ 'Invalid field value: }',
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' [unknown_extension] {\n'
+ ' str: "malformed string\n' # Missing closing quote.
+ ' }\n'
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ 'Invalid field value: "',
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' [unknown_extension] {\n'
+ ' str: "malformed\n multiline\n string\n'
+ ' }\n'
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ 'Invalid field value: "',
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' [malformed_extension] <\n'
+ ' i: -5\n'
+ ' \n' # Missing '>' here.
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ '5:1 : Expected ">".',
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ # Don't allow unknown fields with allow_unknown_extension=True.
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' unknown_field: true\n'
+ ' \n' # Missing '>' here.
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ ('2:3 : Message type '
+ '"proto2_wireformat_unittest.TestMessageSet" has no'
+ ' field named "unknown_field".'),
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ # Parse known extensions correctly.
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ text = ('message_set {\n'
+ ' [protobuf_unittest.TestMessageSetExtension1] {\n'
+ ' i: 23\n'
+ ' }\n'
+ ' [protobuf_unittest.TestMessageSetExtension2] {\n'
+ ' str: \"foo\"\n'
+ ' }\n'
+ '}\n')
+ text_format.Parse(text, message, allow_unknown_extension=True)
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ self.assertEqual(23, message.message_set.Extensions[ext1].i)
+ self.assertEqual('foo', message.message_set.Extensions[ext2].str)
+
def testParseBadExtension(self):
message = unittest_pb2.TestAllExtensions()
text = '[unknown_extension]: 8\n'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
'1:2 : Extension "unknown_extension" not registered.',
text_format.Parse, text, message)
message = unittest_pb2.TestAllTypes()
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
'extensions.'),
@@ -654,7 +800,7 @@ class Proto2Tests(TextFormatBase):
message = unittest_pb2.TestAllExtensions()
text = ('[protobuf_unittest.optional_int32_extension]: 42 '
'[protobuf_unittest.optional_int32_extension]: 67')
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
('1:96 : Message type "protobuf_unittest.TestAllExtensions" '
'should not have multiple '
@@ -665,7 +811,7 @@ class Proto2Tests(TextFormatBase):
message = unittest_pb2.TestAllTypes()
text = ('optional_nested_message { bb: 1 } '
'optional_nested_message { bb: 2 }')
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" '
'should not have multiple "bb" fields.'),
@@ -675,7 +821,7 @@ class Proto2Tests(TextFormatBase):
message = unittest_pb2.TestAllTypes()
text = ('optional_int32: 42 '
'optional_int32: 67')
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError,
('1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
'have multiple "optional_int32" fields.'),
@@ -684,11 +830,11 @@ class Proto2Tests(TextFormatBase):
def testParseGroupNotClosed(self):
message = unittest_pb2.TestAllTypes()
text = 'RepeatedGroup: <'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError, '1:16 : Expected ">".',
text_format.Parse, text, message)
text = 'RepeatedGroup: {'
- six.assertRaisesRegex(self,
+ six.assertRaisesRegex(self,
text_format.ParseError, '1:16 : Expected "}".',
text_format.Parse, text, message)
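
The short list syntax covered by testParseRepeatedScalarShortFormat can be summarized as follows (a sketch; TestAllTypes stands in for any message with repeated fields):

    from google.protobuf import text_format

    msg = unittest_pb2.TestAllTypes()
    text_format.Parse('repeated_int64: [100, 200]\n'
                      'repeated_int64: 300\n'
                      'repeated_string: ["one", "two"]\n', msg)
    assert list(msg.repeated_int64) == [100, 200, 300]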
diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py
new file mode 100644
index 00000000..d3de9831
--- /dev/null
+++ b/python/google/protobuf/internal/well_known_types.py
@@ -0,0 +1,622 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Contains well known classes.
+
+This files defines well known classes which need extra maintenance including:
+ - Any
+ - Duration
+ - FieldMask
+ - Timestamp
+"""
+
+__author__ = 'jieluo@google.com (Jie Luo)'
+
+from datetime import datetime
+from datetime import timedelta
+
+from google.protobuf.descriptor import FieldDescriptor
+
+_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
+_NANOS_PER_SECOND = 1000000000
+_NANOS_PER_MILLISECOND = 1000000
+_NANOS_PER_MICROSECOND = 1000
+_MILLIS_PER_SECOND = 1000
+_MICROS_PER_SECOND = 1000000
+_SECONDS_PER_DAY = 24 * 3600
+
+
+class Error(Exception):
+ """Top-level module error."""
+
+
+class ParseError(Error):
+ """Thrown in case of parsing error."""
+
+
+class Any(object):
+ """Class for Any Message type."""
+
+ def Pack(self, msg):
+ """Packs the specified message into current Any message."""
+ self.type_url = 'type.googleapis.com/%s' % msg.DESCRIPTOR.full_name
+ self.value = msg.SerializeToString()
+
+ def Unpack(self, msg):
+ """Unpacks the current Any message into specified message."""
+ descriptor = msg.DESCRIPTOR
+ if not self.Is(descriptor):
+ return False
+ msg.ParseFromString(self.value)
+ return True
+
+ def Is(self, descriptor):
+ """Checks if this Any represents the given protobuf type."""
+ # Only last part is to be used: b/25630112
+ return self.type_url.split('/')[-1] == descriptor.full_name
+
+
+class Timestamp(object):
+ """Class for Timestamp message type."""
+
+ def ToJsonString(self):
+ """Converts Timestamp to RFC 3339 date string format.
+
+ Returns:
+ A string converted from timestamp. The string is always Z-normalized
+ and uses 3, 6 or 9 fractional digits as required to represent the
+ exact time. Example of the return format: '1972-01-01T10:00:20.021Z'
+ """
+ nanos = self.nanos % _NANOS_PER_SECOND
+ total_sec = self.seconds + (self.nanos - nanos) // _NANOS_PER_SECOND
+ seconds = total_sec % _SECONDS_PER_DAY
+ days = (total_sec - seconds) // _SECONDS_PER_DAY
+ dt = datetime(1970, 1, 1) + timedelta(days, seconds)
+
+ result = dt.isoformat()
+ if (nanos % 1e9) == 0:
+ # If there are 0 fractional digits, the decimal
+ # point '.' should be omitted when serializing.
+ return result + 'Z'
+ if (nanos % 1e6) == 0:
+ # Serialize 3 fractional digits.
+ return result + '.%03dZ' % (nanos / 1e6)
+ if (nanos % 1e3) == 0:
+ # Serialize 6 fractional digits.
+ return result + '.%06dZ' % (nanos / 1e3)
+ # Serialize 9 fractional digits.
+ return result + '.%09dZ' % nanos
+
+ def FromJsonString(self, value):
+ """Parse a RFC 3339 date string format to Timestamp.
+
+ Args:
+ value: A date string. Any fractional digits (or none) and any offset are
+ accepted, as long as they fit into nanosecond precision.
+ Example of accepted format: '1972-01-01T10:00:20.021-05:00'
+
+ Raises:
+ ParseError: On parsing problems.
+ """
+ timezone_offset = value.find('Z')
+ if timezone_offset == -1:
+ timezone_offset = value.find('+')
+ if timezone_offset == -1:
+ timezone_offset = value.rfind('-')
+ if timezone_offset == -1:
+ raise ParseError(
+ 'Failed to parse timestamp: missing valid timezone offset.')
+ time_value = value[0:timezone_offset]
+ # Parse datetime and nanos.
+ point_position = time_value.find('.')
+ if point_position == -1:
+ second_value = time_value
+ nano_value = ''
+ else:
+ second_value = time_value[:point_position]
+ nano_value = time_value[point_position + 1:]
+ date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT)
+ td = date_object - datetime(1970, 1, 1)
+ seconds = td.seconds + td.days * _SECONDS_PER_DAY
+ if len(nano_value) > 9:
+ raise ParseError(
+ 'Failed to parse Timestamp: nanos {0} more than '
+ '9 fractional digits.'.format(nano_value))
+ if nano_value:
+ nanos = round(float('0.' + nano_value) * 1e9)
+ else:
+ nanos = 0
+ # Parse timezone offsets.
+ if value[timezone_offset] == 'Z':
+ if len(value) != timezone_offset + 1:
+ raise ParseError('Failed to parse timestamp: invalid trailing'
+ ' data {0}.'.format(value))
+ else:
+ timezone = value[timezone_offset:]
+ pos = timezone.find(':')
+ if pos == -1:
+ raise ParseError(
+ 'Invalid timezone offset value: {0}.'.format(timezone))
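+ # A '+HH:MM' suffix means the local time is ahead of UTC, so the
+ # offset is subtracted to recover UTC; a '-HH:MM' offset is added back.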
+ if timezone[0] == '+':
+ seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
+ else:
+ seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
+ # Set seconds and nanos
+ self.seconds = int(seconds)
+ self.nanos = int(nanos)
+
+ def GetCurrentTime(self):
+ """Get the current UTC into Timestamp."""
+ self.FromDatetime(datetime.utcnow())
+
+ def ToNanoseconds(self):
+ """Converts Timestamp to nanoseconds since epoch."""
+ return self.seconds * _NANOS_PER_SECOND + self.nanos
+
+ def ToMicroseconds(self):
+ """Converts Timestamp to microseconds since epoch."""
+ return (self.seconds * _MICROS_PER_SECOND +
+ self.nanos // _NANOS_PER_MICROSECOND)
+
+ def ToMilliseconds(self):
+ """Converts Timestamp to milliseconds since epoch."""
+ return (self.seconds * _MILLIS_PER_SECOND +
+ self.nanos // _NANOS_PER_MILLISECOND)
+
+ def ToSeconds(self):
+ """Converts Timestamp to seconds since epoch."""
+ return self.seconds
+
+ def FromNanoseconds(self, nanos):
+ """Converts nanoseconds since epoch to Timestamp."""
+ self.seconds = nanos // _NANOS_PER_SECOND
+ self.nanos = nanos % _NANOS_PER_SECOND
+
+ def FromMicroseconds(self, micros):
+ """Converts microseconds since epoch to Timestamp."""
+ self.seconds = micros // _MICROS_PER_SECOND
+ self.nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND
+
+ def FromMilliseconds(self, millis):
+ """Converts milliseconds since epoch to Timestamp."""
+ self.seconds = millis // _MILLIS_PER_SECOND
+ self.nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND
+
+ def FromSeconds(self, seconds):
+ """Converts seconds since epoch to Timestamp."""
+ self.seconds = seconds
+ self.nanos = 0
+
+ def ToDatetime(self):
+ """Converts Timestamp to datetime."""
+ return datetime.utcfromtimestamp(
+ self.seconds + self.nanos / float(_NANOS_PER_SECOND))
+
+ def FromDatetime(self, dt):
+ """Converts datetime to Timestamp."""
+ td = dt - datetime(1970, 1, 1)
+ self.seconds = td.seconds + td.days * _SECONDS_PER_DAY
+ self.nanos = td.microseconds * _NANOS_PER_MICROSECOND
+
+
+class Duration(object):
+ """Class for Duration message type."""
+
+ def ToJsonString(self):
+ """Converts Duration to string format.
+
+ Returns:
+ A string converted from self. The string will contain
+ 3, 6, or 9 fractional digits depending on the precision required to
+ represent the exact Duration value. For example: "1s", "1.010s",
+ "1.000000100s", "-3.100s"
+ """
+ if self.seconds < 0 or self.nanos < 0:
+ result = '-'
+ seconds = - self.seconds + int((0 - self.nanos) // 1e9)
+ nanos = (0 - self.nanos) % 1e9
+ else:
+ result = ''
+ seconds = self.seconds + int(self.nanos // 1e9)
+ nanos = self.nanos % 1e9
+ result += '%d' % seconds
+ if (nanos % 1e9) == 0:
+ # If there are 0 fractional digits, the decimal
+ # point '.' should be omitted when serializing.
+ return result + 's'
+ if (nanos % 1e6) == 0:
+ # Serialize 3 fractional digits.
+ return result + '.%03ds' % (nanos / 1e6)
+ if (nanos % 1e3) == 0:
+ # Serialize 6 fractional digits.
+ return result + '.%06ds' % (nanos / 1e3)
+ # Serialize 9 fractional digits.
+ return result + '.%09ds' % nanos
+
+ def FromJsonString(self, value):
+ """Converts a string to Duration.
+
+ Args:
+ value: A string to be converted. The string must end with 's'. Any
+ fractional digits (or none) are accepted, as long as they fit into
+ nanosecond precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s".
+
+ Raises:
+ ParseError: On parsing problems.
+ """
+ if len(value) < 1 or value[-1] != 's':
+ raise ParseError(
+ 'Duration must end with letter "s": {0}.'.format(value))
+ try:
+ pos = value.find('.')
+ if pos == -1:
+ self.seconds = int(value[:-1])
+ self.nanos = 0
+ else:
+ self.seconds = int(value[:pos])
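+ # Carry the sign into the fractional part: in a normalized Duration,
+ # seconds and nanos must not have opposite signs.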
+ if value[0] == '-':
+ self.nanos = int(round(float('-0{0}'.format(value[pos:-1])) * 1e9))
+ else:
+ self.nanos = int(round(float('0{0}'.format(value[pos:-1])) * 1e9))
+ except ValueError:
+ raise ParseError(
+ 'Couldn\'t parse duration: {0}.'.format(value))
+
+ def ToNanoseconds(self):
+ """Converts a Duration to nanoseconds."""
+ return self.seconds * _NANOS_PER_SECOND + self.nanos
+
+ def ToMicroseconds(self):
+ """Converts a Duration to microseconds."""
+ micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND)
+ return self.seconds * _MICROS_PER_SECOND + micros
+
+ def ToMilliseconds(self):
+ """Converts a Duration to milliseconds."""
+ millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND)
+ return self.seconds * _MILLIS_PER_SECOND + millis
+
+ def ToSeconds(self):
+ """Converts a Duration to seconds."""
+ return self.seconds
+
+ def FromNanoseconds(self, nanos):
+ """Converts nanoseconds to Duration."""
+ self._NormalizeDuration(nanos // _NANOS_PER_SECOND,
+ nanos % _NANOS_PER_SECOND)
+
+ def FromMicroseconds(self, micros):
+ """Converts microseconds to Duration."""
+ self._NormalizeDuration(
+ micros // _MICROS_PER_SECOND,
+ (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND)
+
+ def FromMilliseconds(self, millis):
+ """Converts milliseconds to Duration."""
+ self._NormalizeDuration(
+ millis // _MILLIS_PER_SECOND,
+ (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND)
+
+ def FromSeconds(self, seconds):
+ """Converts seconds to Duration."""
+ self.seconds = seconds
+ self.nanos = 0
+
+ def ToTimedelta(self):
+ """Converts Duration to timedelta."""
+ return timedelta(
+ seconds=self.seconds, microseconds=_RoundTowardZero(
+ self.nanos, _NANOS_PER_MICROSECOND))
+
+ def FromTimedelta(self, td):
+ """Convertd timedelta to Duration."""
+ self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
+ td.microseconds * _NANOS_PER_MICROSECOND)
+
+ def _NormalizeDuration(self, seconds, nanos):
+ """Set Duration by seconds and nonas."""
+ # Force nanos to be negative if the duration is negative.
+ if seconds < 0 and nanos > 0:
+ seconds += 1
+ nanos -= _NANOS_PER_SECOND
+ self.seconds = seconds
+ self.nanos = nanos
+
+
+def _RoundTowardZero(value, divider):
+ """Truncates the remainder part after division."""
+ # For some languages, the sign of the remainder is implementation
+ # dependent if any of the operands is negative. Here we enforce
+ # "rounded toward zero" semantics. For example, for (-5) / 2 an
+ # implementation may give -3 as the result with the remainder being
+ # 1. This function ensures we always return -2 (closer to zero).
+ result = value // divider
+ remainder = value % divider
+ if result < 0 and remainder > 0:
+ return result + 1
+ else:
+ return result
+
+
+class FieldMask(object):
+ """Class for FieldMask message type."""
+
+ def ToJsonString(self):
+ """Converts FieldMask to string according to proto3 JSON spec."""
+ return ','.join(self.paths)
+
+ def FromJsonString(self, value):
+ """Converts string to FieldMask according to proto3 JSON spec."""
+ self.Clear()
+ for path in value.split(','):
+ self.paths.append(path)
+
+ def IsValidForDescriptor(self, message_descriptor):
+ """Checks whether the FieldMask is valid for Message Descriptor."""
+ for path in self.paths:
+ if not _IsValidPath(message_descriptor, path):
+ return False
+ return True
+
+ def AllFieldsFromDescriptor(self, message_descriptor):
+ """Gets all direct fields of Message Descriptor to FieldMask."""
+ self.Clear()
+ for field in message_descriptor.fields:
+ self.paths.append(field.name)
+
+ def CanonicalFormFromMask(self, mask):
+ """Converts a FieldMask to the canonical form.
+
+ Removes paths that are covered by another path. For example,
+ "foo.bar" is covered by "foo" and will be removed if "foo"
+ is also in the FieldMask. Then sorts all paths in alphabetical order.
+
+ Args:
+ mask: The original FieldMask to be converted.
+ """
+ tree = _FieldMaskTree(mask)
+ tree.ToFieldMask(self)
+
+ def Union(self, mask1, mask2):
+ """Merges mask1 and mask2 into this FieldMask."""
+ _CheckFieldMaskMessage(mask1)
+ _CheckFieldMaskMessage(mask2)
+ tree = _FieldMaskTree(mask1)
+ tree.MergeFromFieldMask(mask2)
+ tree.ToFieldMask(self)
+
+ def Intersect(self, mask1, mask2):
+ """Intersects mask1 and mask2 into this FieldMask."""
+ _CheckFieldMaskMessage(mask1)
+ _CheckFieldMaskMessage(mask2)
+ tree = _FieldMaskTree(mask1)
+ intersection = _FieldMaskTree()
+ for path in mask2.paths:
+ tree.IntersectPath(path, intersection)
+ intersection.ToFieldMask(self)
+
+ def MergeMessage(
+ self, source, destination,
+ replace_message_field=False, replace_repeated_field=False):
+ """Merges fields specified in FieldMask from source to destination.
+
+ Args:
+ source: Source message.
+ destination: The destination message to be merged into.
+ replace_message_field: Replace message field if True. Merge message
+ field if False.
+ replace_repeated_field: Replace repeated field if True. Append
+ elements of repeated field if False.
+ """
+ tree = _FieldMaskTree(self)
+ tree.MergeMessage(
+ source, destination, replace_message_field, replace_repeated_field)
+
+
+def _IsValidPath(message_descriptor, path):
+ """Checks whether the path is valid for Message Descriptor."""
+ parts = path.split('.')
+ last = parts.pop()
+ for name in parts:
+ field = message_descriptor.fields_by_name[name]
+ if (field is None or
+ field.label == FieldDescriptor.LABEL_REPEATED or
+ field.type != FieldDescriptor.TYPE_MESSAGE):
+ return False
+ message_descriptor = field.message_type
+ return last in message_descriptor.fields_by_name
+
+
+def _CheckFieldMaskMessage(message):
+ """Raises ValueError if message is not a FieldMask."""
+ message_descriptor = message.DESCRIPTOR
+ if (message_descriptor.name != 'FieldMask' or
+ message_descriptor.file.name != 'google/protobuf/field_mask.proto'):
+ raise ValueError('Message {0} is not a FieldMask.'.format(
+ message_descriptor.full_name))
+
+
+class _FieldMaskTree(object):
+ """Represents a FieldMask in a tree structure.
+
+ For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
+ the FieldMaskTree will be:
+ [_root] -+- foo -+- bar
+ | |
+ | +- baz
+ |
+ +- bar --- baz
+ In the tree, each leaf node represents a field path.
+ """
+
+ def __init__(self, field_mask=None):
+ """Initializes the tree by FieldMask."""
+ self._root = {}
+ if field_mask:
+ self.MergeFromFieldMask(field_mask)
+
+ def MergeFromFieldMask(self, field_mask):
+ """Merges a FieldMask to the tree."""
+ for path in field_mask.paths:
+ self.AddPath(path)
+
+ def AddPath(self, path):
+ """Adds a field path into the tree.
+
+ If the field path to add is a sub-path of an existing field path
+ in the tree (i.e., a leaf node), it means the tree already matches
+ the given path so nothing will be added to the tree. If the path
+ matches an existing non-leaf node in the tree, that non-leaf node
+ will be turned into a leaf node with all its children removed because
+ the path matches all the node's children. Otherwise, a new path will
+ be added.
+
+ Args:
+ path: The field path to add.
+ """
+ node = self._root
+ for name in path.split('.'):
+ if name not in node:
+ node[name] = {}
+ elif not node[name]:
+ # Pre-existing empty node implies we already have this entire tree.
+ return
+ node = node[name]
+ # Remove any sub-trees we might have had.
+ node.clear()
+
+ def ToFieldMask(self, field_mask):
+ """Converts the tree to a FieldMask."""
+ field_mask.Clear()
+ _AddFieldPaths(self._root, '', field_mask)
+
+ def IntersectPath(self, path, intersection):
+ """Calculates the intersection part of a field path with this tree.
+
+ Args:
+ path: The field path to calculate.
+ intersection: The output tree recording the intersection part.
+ """
+ node = self._root
+ for name in path.split('.'):
+ if name not in node:
+ return
+ elif not node[name]:
+ intersection.AddPath(path)
+ return
+ node = node[name]
+ intersection.AddLeafNodes(path, node)
+
+ def AddLeafNodes(self, prefix, node):
+ """Adds leaf nodes begin with prefix to this tree."""
+ if not node:
+ self.AddPath(prefix)
+ for name in node:
+ child_path = prefix + '.' + name
+ self.AddLeafNodes(child_path, node[name])
+
+ def MergeMessage(
+ self, source, destination,
+ replace_message, replace_repeated):
+ """Merge all fields specified by this tree from source to destination."""
+ _MergeMessage(
+ self._root, source, destination, replace_message, replace_repeated)
+
+
+def _StrConvert(value):
+ """Converts value to str if it is not."""
+ # This file is imported by c extension and some methods like ClearField
+ # requires string for the field name. py2/py3 has different text
+ # type and may use unicode.
+ if not isinstance(value, str):
+ return value.encode('utf-8')
+ return value
+
+
+def _MergeMessage(
+ node, source, destination, replace_message, replace_repeated):
+ """Merge all fields specified by a sub-tree from source to destination."""
+ source_descriptor = source.DESCRIPTOR
+ for name in node:
+ child = node[name]
+ field = source_descriptor.fields_by_name[name]
+ if field is None:
+ raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
+ name, source_descriptor.full_name))
+ if child:
+ # Sub-paths are only allowed for singular message fields.
+ if (field.label == FieldDescriptor.LABEL_REPEATED or
+ field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
+ raise ValueError('Error: Field {0} in message {1} is not a singular '
+ 'message field and cannot have sub-fields.'.format(
+ name, source_descriptor.full_name))
+ _MergeMessage(
+ child, getattr(source, name), getattr(destination, name),
+ replace_message, replace_repeated)
+ continue
+ if field.label == FieldDescriptor.LABEL_REPEATED:
+ if replace_repeated:
+ destination.ClearField(_StrConvert(name))
+ repeated_source = getattr(source, name)
+ repeated_destination = getattr(destination, name)
+ if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
+ for item in repeated_source:
+ repeated_destination.add().MergeFrom(item)
+ else:
+ repeated_destination.extend(repeated_source)
+ else:
+ if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
+ if replace_message:
+ destination.ClearField(_StrConvert(name))
+ if source.HasField(name):
+ getattr(destination, name).MergeFrom(getattr(source, name))
+ else:
+ setattr(destination, name, getattr(source, name))
+
+
+def _AddFieldPaths(node, prefix, field_mask):
+ """Adds the field paths descended from node to field_mask."""
+ if not node:
+ field_mask.paths.append(prefix)
+ return
+ for name in sorted(node):
+ if prefix:
+ child_path = prefix + '.' + name
+ else:
+ child_path = name
+ _AddFieldPaths(node[name], child_path, field_mask)
+
+
+WKTBASES = {
+ 'google.protobuf.Any': Any,
+ 'google.protobuf.Duration': Duration,
+ 'google.protobuf.FieldMask': FieldMask,
+ 'google.protobuf.Timestamp': Timestamp,
+}
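The WKTBASES mapping above ties each well-known type to its helper class; once wired into message class construction, the helpers are available directly on the generated messages. A minimal usage sketch, assuming the standard timestamp_pb2 and duration_pb2 modules:

    from datetime import timedelta

    from google.protobuf import duration_pb2
    from google.protobuf import timestamp_pb2

    ts = timestamp_pb2.Timestamp()
    ts.FromJsonString('1972-01-01T10:00:20.021-05:00')
    print(ts.ToJsonString())  # 1972-01-01T15:00:20.021Z (offset folded into UTC)

    d = duration_pb2.Duration()
    d.FromTimedelta(timedelta(seconds=1, microseconds=10))
    print(d.ToJsonString())  # 1.000010s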
diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py
new file mode 100644
index 00000000..60b0c47d
--- /dev/null
+++ b/python/google/protobuf/internal/well_known_types_test.py
@@ -0,0 +1,509 @@
+#! /usr/bin/env python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test for google.protobuf.internal.well_known_types."""
+
+__author__ = 'jieluo@google.com (Jie Luo)'
+
+from datetime import datetime
+
+from google.protobuf import duration_pb2
+from google.protobuf import field_mask_pb2
+from google.protobuf import timestamp_pb2
+import unittest
+from google.protobuf import unittest_pb2
+from google.protobuf.internal import test_util
+from google.protobuf.internal import well_known_types
+from google.protobuf import descriptor
+
+
+class TimeUtilTestBase(unittest.TestCase):
+
+ def CheckTimestampConversion(self, message, text):
+ self.assertEqual(text, message.ToJsonString())
+ parsed_message = timestamp_pb2.Timestamp()
+ parsed_message.FromJsonString(text)
+ self.assertEqual(message, parsed_message)
+
+ def CheckDurationConversion(self, message, text):
+ self.assertEqual(text, message.ToJsonString())
+ parsed_message = duration_pb2.Duration()
+ parsed_message.FromJsonString(text)
+ self.assertEqual(message, parsed_message)
+
+
+class TimeUtilTest(TimeUtilTestBase):
+
+ def testTimestampSerializeAndParse(self):
+ message = timestamp_pb2.Timestamp()
+ # Generated output should contain 3, 6, or 9 fractional digits.
+ message.seconds = 0
+ message.nanos = 0
+ self.CheckTimestampConversion(message, '1970-01-01T00:00:00Z')
+ message.nanos = 10000000
+ self.CheckTimestampConversion(message, '1970-01-01T00:00:00.010Z')
+ message.nanos = 10000
+ self.CheckTimestampConversion(message, '1970-01-01T00:00:00.000010Z')
+ message.nanos = 10
+ self.CheckTimestampConversion(message, '1970-01-01T00:00:00.000000010Z')
+ # Test min timestamps.
+ message.seconds = -62135596800
+ message.nanos = 0
+ self.CheckTimestampConversion(message, '0001-01-01T00:00:00Z')
+ # Test max timestamps.
+ message.seconds = 253402300799
+ message.nanos = 999999999
+ self.CheckTimestampConversion(message, '9999-12-31T23:59:59.999999999Z')
+ # Test negative timestamps.
+ message.seconds = -1
+ self.CheckTimestampConversion(message, '1969-12-31T23:59:59.999999999Z')
+
+ # Parsing accepts any number of fractional digits, as long as they fit
+ # into nanosecond precision.
+ message.FromJsonString('1970-01-01T00:00:00.1Z')
+ self.assertEqual(0, message.seconds)
+ self.assertEqual(100000000, message.nanos)
+ # Parsing accepts offsets.
+ message.FromJsonString('1970-01-01T00:00:00-08:00')
+ self.assertEqual(8 * 3600, message.seconds)
+ self.assertEqual(0, message.nanos)
+
+ def testDurationSerializeAndParse(self):
+ message = duration_pb2.Duration()
+ # Generated output should contain 3, 6, or 9 fractional digits.
+ message.seconds = 0
+ message.nanos = 0
+ self.CheckDurationConversion(message, '0s')
+ message.nanos = 10000000
+ self.CheckDurationConversion(message, '0.010s')
+ message.nanos = 10000
+ self.CheckDurationConversion(message, '0.000010s')
+ message.nanos = 10
+ self.CheckDurationConversion(message, '0.000000010s')
+
+ # Test min and max
+ message.seconds = 315576000000
+ message.nanos = 999999999
+ self.CheckDurationConversion(message, '315576000000.999999999s')
+ message.seconds = -315576000000
+ message.nanos = -999999999
+ self.CheckDurationConversion(message, '-315576000000.999999999s')
+
+ # Parsing accepts any number of fractional digits, as long as they fit
+ # into nanosecond precision.
+ message.FromJsonString('0.1s')
+ self.assertEqual(100000000, message.nanos)
+ message.FromJsonString('0.0000001s')
+ self.assertEqual(100, message.nanos)
+
+ def testTimestampIntegerConversion(self):
+ message = timestamp_pb2.Timestamp()
+ message.FromNanoseconds(1)
+ self.assertEqual('1970-01-01T00:00:00.000000001Z',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToNanoseconds())
+
+ message.FromNanoseconds(-1)
+ self.assertEqual('1969-12-31T23:59:59.999999999Z',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToNanoseconds())
+
+ message.FromMicroseconds(1)
+ self.assertEqual('1970-01-01T00:00:00.000001Z',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToMicroseconds())
+
+ message.FromMicroseconds(-1)
+ self.assertEqual('1969-12-31T23:59:59.999999Z',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToMicroseconds())
+
+ message.FromMilliseconds(1)
+ self.assertEqual('1970-01-01T00:00:00.001Z',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToMilliseconds())
+
+ message.FromMilliseconds(-1)
+ self.assertEqual('1969-12-31T23:59:59.999Z',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToMilliseconds())
+
+ message.FromSeconds(1)
+ self.assertEqual('1970-01-01T00:00:01Z',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToSeconds())
+
+ message.FromSeconds(-1)
+ self.assertEqual('1969-12-31T23:59:59Z',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToSeconds())
+
+ message.FromNanoseconds(1999)
+ self.assertEqual(1, message.ToMicroseconds())
+ # For negative values, Timestamp will be rounded down.
+ # For example, "1969-12-31T23:59:59.5Z" (i.e., -0.5s) rounded to seconds
+ # will be "1969-12-31T23:59:59Z" (i.e., -1s) rather than
+ # "1970-01-01T00:00:00Z" (i.e., 0s).
+ message.FromNanoseconds(-1999)
+ self.assertEqual(-2, message.ToMicroseconds())
+
+ def testDurationIntegerConversion(self):
+ message = duration_pb2.Duration()
+ message.FromNanoseconds(1)
+ self.assertEqual('0.000000001s',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToNanoseconds())
+
+ message.FromNanoseconds(-1)
+ self.assertEqual('-0.000000001s',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToNanoseconds())
+
+ message.FromMicroseconds(1)
+ self.assertEqual('0.000001s',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToMicroseconds())
+
+ message.FromMicroseconds(-1)
+ self.assertEqual('-0.000001s',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToMicroseconds())
+
+ message.FromMilliseconds(1)
+ self.assertEqual('0.001s',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToMilliseconds())
+
+ message.FromMilliseconds(-1)
+ self.assertEqual('-0.001s',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToMilliseconds())
+
+ message.FromSeconds(1)
+ self.assertEqual('1s', message.ToJsonString())
+ self.assertEqual(1, message.ToSeconds())
+
+ message.FromSeconds(-1)
+ self.assertEqual('-1s',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToSeconds())
+
+ # Test truncation behavior.
+ message.FromNanoseconds(1999)
+ self.assertEqual(1, message.ToMicroseconds())
+
+ # For negative values, Duration will be rounded towards 0.
+ message.FromNanoseconds(-1999)
+ self.assertEqual(-1, message.ToMicroseconds())
+
+ def testDatetimeConversion(self):
+ message = timestamp_pb2.Timestamp()
+ dt = datetime(1970, 1, 1)
+ message.FromDatetime(dt)
+ self.assertEqual(dt, message.ToDatetime())
+
+ message.FromMilliseconds(1999)
+ self.assertEqual(datetime(1970, 1, 1, 0, 0, 1, 999000),
+ message.ToDatetime())
+
+ def testTimedeltaConversion(self):
+ message = duration_pb2.Duration()
+ message.FromNanoseconds(1999999999)
+ td = message.ToTimedelta()
+ self.assertEqual(1, td.seconds)
+ self.assertEqual(999999, td.microseconds)
+
+ message.FromNanoseconds(-1999999999)
+ td = message.ToTimedelta()
+ self.assertEqual(-1, td.days)
+ self.assertEqual(86398, td.seconds)
+ self.assertEqual(1, td.microseconds)
+
+ message.FromMicroseconds(-1)
+ td = message.ToTimedelta()
+ self.assertEqual(-1, td.days)
+ self.assertEqual(86399, td.seconds)
+ self.assertEqual(999999, td.microseconds)
+ converted_message = duration_pb2.Duration()
+ converted_message.FromTimedelta(td)
+ self.assertEqual(message, converted_message)
+
+ def testInvalidTimestamp(self):
+ message = timestamp_pb2.Timestamp()
+ self.assertRaisesRegexp(
+ ValueError,
+ 'time data \'10000-01-01T00:00:00\' does not match'
+ ' format \'%Y-%m-%dT%H:%M:%S\'',
+ message.FromJsonString, '10000-01-01T00:00:00.00Z')
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'nanos 0123456789012 more than 9 fractional digits.',
+ message.FromJsonString,
+ '1970-01-01T00:00:00.0123456789012Z')
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ (r'Invalid timezone offset value: \+08.'),
+ message.FromJsonString,
+ '1972-01-01T01:00:00.01+08',)
+ self.assertRaisesRegexp(
+ ValueError,
+ 'year is out of range',
+ message.FromJsonString,
+ '0000-01-01T00:00:00Z')
+ message.seconds = 253402300800
+ self.assertRaisesRegexp(
+ OverflowError,
+ 'date value out of range',
+ message.ToJsonString)
+
+ def testInvalidDuration(self):
+ message = duration_pb2.Duration()
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'Duration must end with letter "s": 1.',
+ message.FromJsonString, '1')
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'Couldn\'t parse duration: 1...2s.',
+ message.FromJsonString, '1...2s')
+
+
+class FieldMaskTest(unittest.TestCase):
+
+ def testStringFormat(self):
+ mask = field_mask_pb2.FieldMask()
+ self.assertEqual('', mask.ToJsonString())
+ mask.paths.append('foo')
+ self.assertEqual('foo', mask.ToJsonString())
+ mask.paths.append('bar')
+ self.assertEqual('foo,bar', mask.ToJsonString())
+
+ mask.FromJsonString('')
+ self.assertEqual('', mask.ToJsonString())
+ mask.FromJsonString('foo')
+ self.assertEqual(['foo'], mask.paths)
+ mask.FromJsonString('foo,bar')
+ self.assertEqual(['foo', 'bar'], mask.paths)
+
+ def testDescriptorToFieldMask(self):
+ mask = field_mask_pb2.FieldMask()
+ msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ mask.AllFieldsFromDescriptor(msg_descriptor)
+ self.assertEqual(75, len(mask.paths))
+ self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
+ for field in msg_descriptor.fields:
+ self.assertTrue(field.name in mask.paths)
+ mask.paths.append('optional_nested_message.bb')
+ self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
+ mask.paths.append('repeated_nested_message.bb')
+ self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
+
+ def testCanonicalFrom(self):
+ mask = field_mask_pb2.FieldMask()
+ out_mask = field_mask_pb2.FieldMask()
+ # Paths will be sorted.
+ mask.FromJsonString('baz.quz,bar,foo')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('bar,baz.quz,foo', out_mask.ToJsonString())
+ # Duplicated paths will be removed.
+ mask.FromJsonString('foo,bar,foo')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('bar,foo', out_mask.ToJsonString())
+ # Sub-paths of other paths will be removed.
+ mask.FromJsonString('foo.b1,bar.b1,foo.b2,bar')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('bar,foo.b1,foo.b2', out_mask.ToJsonString())
+
+ # Test more deeply nested cases.
+ mask.FromJsonString(
+ 'foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('foo.bar.baz1,foo.bar.baz2',
+ out_mask.ToJsonString())
+ mask.FromJsonString(
+ 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('foo.bar.baz1,foo.bar.baz2',
+ out_mask.ToJsonString())
+ mask.FromJsonString(
+ 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('foo.bar', out_mask.ToJsonString())
+ mask.FromJsonString(
+ 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('foo', out_mask.ToJsonString())
+
+ def testUnion(self):
+ mask1 = field_mask_pb2.FieldMask()
+ mask2 = field_mask_pb2.FieldMask()
+ out_mask = field_mask_pb2.FieldMask()
+ mask1.FromJsonString('foo,baz')
+ mask2.FromJsonString('bar,quz')
+ out_mask.Union(mask1, mask2)
+ self.assertEqual('bar,baz,foo,quz', out_mask.ToJsonString())
+ # Overlap with duplicated paths.
+ mask1.FromJsonString('foo,baz.bb')
+ mask2.FromJsonString('baz.bb,quz')
+ out_mask.Union(mask1, mask2)
+ self.assertEqual('baz.bb,foo,quz', out_mask.ToJsonString())
+ # Overlap with paths covering some other paths.
+ mask1.FromJsonString('foo.bar.baz,quz')
+ mask2.FromJsonString('foo.bar,bar')
+ out_mask.Union(mask1, mask2)
+ self.assertEqual('bar,foo.bar,quz', out_mask.ToJsonString())
+
+ def testIntersect(self):
+ mask1 = field_mask_pb2.FieldMask()
+ mask2 = field_mask_pb2.FieldMask()
+ out_mask = field_mask_pb2.FieldMask()
+ # Test cases without overlapping.
+ mask1.FromJsonString('foo,baz')
+ mask2.FromJsonString('bar,quz')
+ out_mask.Intersect(mask1, mask2)
+ self.assertEqual('', out_mask.ToJsonString())
+ # Overlap with duplicated paths.
+ mask1.FromJsonString('foo,baz.bb')
+ mask2.FromJsonString('baz.bb,quz')
+ out_mask.Intersect(mask1, mask2)
+ self.assertEqual('baz.bb', out_mask.ToJsonString())
+ # Overlap with paths covering some other paths.
+ mask1.FromJsonString('foo.bar.baz,quz')
+ mask2.FromJsonString('foo.bar,bar')
+ out_mask.Intersect(mask1, mask2)
+ self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
+ mask1.FromJsonString('foo.bar,bar')
+ mask2.FromJsonString('foo.bar.baz,quz')
+ out_mask.Intersect(mask1, mask2)
+ self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
+
+ def testMergeMessage(self):
+ # Test merge one field.
+ src = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(src)
+ for field in src.DESCRIPTOR.fields:
+ if field.containing_oneof:
+ continue
+ field_name = field.name
+ dst = unittest_pb2.TestAllTypes()
+ # Only set one path to mask.
+ mask = field_mask_pb2.FieldMask()
+ mask.paths.append(field_name)
+ mask.MergeMessage(src, dst)
+ # The expected result message.
+ msg = unittest_pb2.TestAllTypes()
+ if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ repeated_src = getattr(src, field_name)
+ repeated_msg = getattr(msg, field_name)
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ for item in repeated_src:
+ repeated_msg.add().CopyFrom(item)
+ else:
+ repeated_msg.extend(repeated_src)
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ getattr(msg, field_name).CopyFrom(getattr(src, field_name))
+ else:
+ setattr(msg, field_name, getattr(src, field_name))
+ # Only field specified in mask is merged.
+ self.assertEqual(msg, dst)
+
+ # Test merge nested fields.
+ nested_src = unittest_pb2.NestedTestAllTypes()
+ nested_dst = unittest_pb2.NestedTestAllTypes()
+ nested_src.child.payload.optional_int32 = 1234
+ nested_src.child.child.payload.optional_int32 = 5678
+ mask = field_mask_pb2.FieldMask()
+ mask.FromJsonString('child.payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(0, nested_dst.child.child.payload.optional_int32)
+
+ mask.FromJsonString('child.child.payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
+
+ nested_dst.Clear()
+ mask.FromJsonString('child.child.payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(0, nested_dst.child.payload.optional_int32)
+ self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
+
+ nested_dst.Clear()
+ mask.FromJsonString('child')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
+
+ # Test MergeOptions.
+ nested_dst.Clear()
+ nested_dst.child.payload.optional_int64 = 4321
+ # Message fields will be merged by default.
+ mask.FromJsonString('child.payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(4321, nested_dst.child.payload.optional_int64)
+ # Change the behavior to replace message fields.
+ mask.FromJsonString('child.payload')
+ mask.MergeMessage(nested_src, nested_dst, True, False)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(0, nested_dst.child.payload.optional_int64)
+
+ # By default, fields missing in source are not cleared in destination.
+ nested_dst.payload.optional_int32 = 1234
+ self.assertTrue(nested_dst.HasField('payload'))
+ mask.FromJsonString('payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertTrue(nested_dst.HasField('payload'))
+ # But they are cleared when replacing message fields.
+ nested_dst.Clear()
+ nested_dst.payload.optional_int32 = 1234
+ mask.FromJsonString('payload')
+ mask.MergeMessage(nested_src, nested_dst, True, False)
+ self.assertFalse(nested_dst.HasField('payload'))
+
+ nested_src.payload.repeated_int32.append(1234)
+ nested_dst.payload.repeated_int32.append(5678)
+ # Repeated fields will be appended by default.
+ mask.FromJsonString('payload.repeated_int32')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(2, len(nested_dst.payload.repeated_int32))
+ self.assertEqual(5678, nested_dst.payload.repeated_int32[0])
+ self.assertEqual(1234, nested_dst.payload.repeated_int32[1])
+ # Change the behavior to replace repeated fields.
+ mask.FromJsonString('payload.repeated_int32')
+ mask.MergeMessage(nested_src, nested_dst, False, True)
+ self.assertEqual(1, len(nested_dst.payload.repeated_int32))
+ self.assertEqual(1234, nested_dst.payload.repeated_int32[0])
+
+if __name__ == '__main__':
+ unittest.main()
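The MergeMessage tests above walk many corner cases; the common case is compact. A condensed sketch against the same unittest proto:

    from google.protobuf import field_mask_pb2
    from google.protobuf import unittest_pb2

    src = unittest_pb2.TestAllTypes(optional_int32=42)
    src.optional_nested_message.bb = 7
    dst = unittest_pb2.TestAllTypes()

    mask = field_mask_pb2.FieldMask()
    mask.FromJsonString('optional_int32,optional_nested_message.bb')
    mask.MergeMessage(src, dst)
    print(dst.optional_int32, dst.optional_nested_message.bb)  # 42 7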
diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py
index d95557d7..cb76e116 100644
--- a/python/google/protobuf/json_format.py
+++ b/python/google/protobuf/json_format.py
@@ -28,22 +28,29 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-"""Contains routines for printing protocol messages in JSON format."""
+"""Contains routines for printing protocol messages in JSON format.
+
+Simple usage example:
+
+ # Create a proto object and serialize it to a JSON string.
+ message = my_proto_pb2.MyMessage(foo='bar')
+ json_string = json_format.MessageToJson(message)
+
+ # Parse a JSON string into a proto object.
+ message = json_format.Parse(json_string, my_proto_pb2.MyMessage())
+"""
__author__ = 'jieluo@google.com (Jie Luo)'
import base64
-from datetime import datetime
import json
import math
-import re
+from six import text_type
import sys
from google.protobuf import descriptor
_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
-_NUMBER = re.compile(u'[0-9+-][0-9e.+-]*')
-_INTEGER = re.compile(u'[0-9+-]')
_INT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT32,
descriptor.FieldDescriptor.CPPTYPE_UINT32,
descriptor.FieldDescriptor.CPPTYPE_INT64,
@@ -52,17 +59,20 @@ _INT64_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT64,
descriptor.FieldDescriptor.CPPTYPE_UINT64])
_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT,
descriptor.FieldDescriptor.CPPTYPE_DOUBLE])
-if str is bytes:
- _UNICODETYPE = unicode
-else:
- _UNICODETYPE = str
+_INFINITY = 'Infinity'
+_NEG_INFINITY = '-Infinity'
+_NAN = 'NaN'
+
+class Error(Exception):
+ """Top-level module error for json_format."""
-class SerializeToJsonError(Exception):
+
+class SerializeToJsonError(Error):
"""Thrown if serialization to JSON fails."""
-class ParseError(Exception):
+class ParseError(Error):
"""Thrown in case of parsing error."""
@@ -86,12 +96,8 @@ def MessageToJson(message, including_default_value_fields=False):
def _MessageToJsonObject(message, including_default_value_fields):
"""Converts message to an object according to Proto3 JSON Specification."""
message_descriptor = message.DESCRIPTOR
- if _IsTimestampMessage(message_descriptor):
- return _TimestampMessageToJsonObject(message)
- if _IsDurationMessage(message_descriptor):
- return _DurationMessageToJsonObject(message)
- if _IsFieldMaskMessage(message_descriptor):
- return _FieldMaskMessageToJsonObject(message)
+ if hasattr(message, 'ToJsonString'):
+ return message.ToJsonString()
if _IsWrapperMessage(message_descriptor):
return _WrapperMessageToJsonObject(message)
return _RegularMessageToJsonObject(message, including_default_value_fields)
@@ -107,12 +113,14 @@ def _RegularMessageToJsonObject(message, including_default_value_fields):
"""Converts normal message according to Proto3 JSON Specification."""
js = {}
fields = message.ListFields()
+ include_default = including_default_value_fields
try:
for field, value in fields:
name = field.camelcase_name
if _IsMapEntry(field):
# Convert a map field.
+ v_field = field.message_type.fields_by_name['value']
js_map = {}
for key in value:
if isinstance(key, bool):
@@ -122,20 +130,15 @@ def _RegularMessageToJsonObject(message, including_default_value_fields):
recorded_key = 'false'
else:
recorded_key = key
- js_map[recorded_key] = _ConvertFieldToJsonObject(
- field.message_type.fields_by_name['value'],
- value[key], including_default_value_fields)
+ js_map[recorded_key] = _FieldToJsonObject(
+ v_field, value[key], including_default_value_fields)
js[name] = js_map
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
# Convert a repeated field.
- repeated = []
- for element in value:
- repeated.append(_ConvertFieldToJsonObject(
- field, element, including_default_value_fields))
- js[name] = repeated
+ js[name] = [_FieldToJsonObject(field, k, include_default)
+ for k in value]
else:
- js[name] = _ConvertFieldToJsonObject(
- field, value, including_default_value_fields)
+ js[name] = _FieldToJsonObject(field, value, include_default)
# Serialize default value if including_default_value_fields is True.
if including_default_value_fields:
@@ -155,16 +158,16 @@ def _RegularMessageToJsonObject(message, including_default_value_fields):
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
js[name] = []
else:
- js[name] = _ConvertFieldToJsonObject(field, field.default_value)
+ js[name] = _FieldToJsonObject(field, field.default_value)
except ValueError as e:
raise SerializeToJsonError(
- 'Failed to serialize {0} field: {1}'.format(field.name, e))
+ 'Failed to serialize {0} field: {1}.'.format(field.name, e))
return js
-def _ConvertFieldToJsonObject(
+def _FieldToJsonObject(
field, value, including_default_value_fields=False):
"""Converts field value according to Proto3 JSON Specification."""
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
@@ -183,101 +186,26 @@ def _ConvertFieldToJsonObject(
else:
return value
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
- if value:
- return True
- else:
- return False
+ return bool(value)
elif field.cpp_type in _INT64_TYPES:
return str(value)
elif field.cpp_type in _FLOAT_TYPES:
if math.isinf(value):
if value < 0.0:
- return '-Infinity'
+ return _NEG_INFINITY
else:
- return 'Infinity'
+ return _INFINITY
if math.isnan(value):
- return 'NaN'
+ return _NAN
return value
-def _IsTimestampMessage(message_descriptor):
- return (message_descriptor.name == 'Timestamp' and
- message_descriptor.file.name == 'google/protobuf/timestamp.proto')
-
-
-def _TimestampMessageToJsonObject(message):
- """Converts Timestamp message according to Proto3 JSON Specification."""
- nanos = message.nanos % 1e9
- dt = datetime.utcfromtimestamp(
- message.seconds + (message.nanos - nanos) / 1e9)
- result = dt.isoformat()
- if (nanos % 1e9) == 0:
- # If there are 0 fractional digits, the fractional
- # point '.' should be omitted when serializing.
- return result + 'Z'
- if (nanos % 1e6) == 0:
- # Serialize 3 fractional digits.
- return result + '.%03dZ' % (nanos / 1e6)
- if (nanos % 1e3) == 0:
- # Serialize 6 fractional digits.
- return result + '.%06dZ' % (nanos / 1e3)
- # Serialize 9 fractional digits.
- return result + '.%09dZ' % nanos
-
-
-def _IsDurationMessage(message_descriptor):
- return (message_descriptor.name == 'Duration' and
- message_descriptor.file.name == 'google/protobuf/duration.proto')
-
-
-def _DurationMessageToJsonObject(message):
- """Converts Duration message according to Proto3 JSON Specification."""
- if message.seconds < 0 or message.nanos < 0:
- result = '-'
- seconds = - message.seconds + int((0 - message.nanos) / 1e9)
- nanos = (0 - message.nanos) % 1e9
- else:
- result = ''
- seconds = message.seconds + int(message.nanos / 1e9)
- nanos = message.nanos % 1e9
- result += '%d' % seconds
- if (nanos % 1e9) == 0:
- # If there are 0 fractional digits, the fractional
- # point '.' should be omitted when serializing.
- return result + 's'
- if (nanos % 1e6) == 0:
- # Serialize 3 fractional digits.
- return result + '.%03ds' % (nanos / 1e6)
- if (nanos % 1e3) == 0:
- # Serialize 6 fractional digits.
- return result + '.%06ds' % (nanos / 1e3)
- # Serialize 9 fractional digits.
- return result + '.%09ds' % nanos
-
-
-def _IsFieldMaskMessage(message_descriptor):
- return (message_descriptor.name == 'FieldMask' and
- message_descriptor.file.name == 'google/protobuf/field_mask.proto')
-
-
-def _FieldMaskMessageToJsonObject(message):
- """Converts FieldMask message according to Proto3 JSON Specification."""
- result = ''
- first = True
- for path in message.paths:
- if not first:
- result += ','
- result += path
- first = False
- return result
-
-
def _IsWrapperMessage(message_descriptor):
return message_descriptor.file.name == 'google/protobuf/wrappers.proto'
def _WrapperMessageToJsonObject(message):
- return _ConvertFieldToJsonObject(
+ return _FieldToJsonObject(
message.DESCRIPTOR.fields_by_name['value'], message.value)
@@ -285,7 +213,7 @@ def _DuplicateChecker(js):
result = {}
for name, value in js:
if name in result:
- raise ParseError('Failed to load JSON: duplicate key ' + name)
+ raise ParseError('Failed to load JSON: duplicate key {0}.'.format(name))
result[name] = value
return result
@@ -303,7 +231,7 @@ def Parse(text, message):
Raises::
ParseError: On JSON parsing problems.
"""
- if not isinstance(text, _UNICODETYPE): text = text.decode('utf-8')
+ if not isinstance(text, text_type): text = text.decode('utf-8')
try:
if sys.version_info < (2, 7):
# object_pair_hook is not supported before python2.7
@@ -311,7 +239,7 @@ def Parse(text, message):
else:
js = json.loads(text, object_pairs_hook=_DuplicateChecker)
except ValueError as e:
- raise ParseError('Failed to load JSON: ' + str(e))
+ raise ParseError('Failed to load JSON: {0}.'.format(str(e)))
_ConvertFieldValuePair(js, message)
return message
@@ -362,7 +290,7 @@ def _ConvertFieldValuePair(js, message):
message.ClearField(field.name)
if not isinstance(value, list):
raise ParseError('repeated field {0} must be in [] which is '
- '{1}'.format(name, value))
+ '{1}.'.format(name, value))
for item in value:
if item is None:
continue
@@ -383,9 +311,9 @@ def _ConvertFieldValuePair(js, message):
else:
raise ParseError(str(e))
except ValueError as e:
- raise ParseError('Failed to parse {0} field: {1}'.format(name, e))
+ raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
except TypeError as e:
- raise ParseError('Failed to parse {0} field: {1}'.format(name, e))
+ raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
def _ConvertMessage(value, message):
@@ -399,88 +327,13 @@ def _ConvertMessage(value, message):
ParseError: In case of convert problems.
"""
message_descriptor = message.DESCRIPTOR
- if _IsTimestampMessage(message_descriptor):
- _ConvertTimestampMessage(value, message)
- elif _IsDurationMessage(message_descriptor):
- _ConvertDurationMessage(value, message)
- elif _IsFieldMaskMessage(message_descriptor):
- _ConvertFieldMaskMessage(value, message)
+ if hasattr(message, 'FromJsonString'):
+ message.FromJsonString(value)
elif _IsWrapperMessage(message_descriptor):
_ConvertWrapperMessage(value, message)
else:
_ConvertFieldValuePair(value, message)
-
-def _ConvertTimestampMessage(value, message):
- """Convert a JSON representation into Timestamp message."""
- timezone_offset = value.find('Z')
- if timezone_offset == -1:
- timezone_offset = value.find('+')
- if timezone_offset == -1:
- timezone_offset = value.rfind('-')
- if timezone_offset == -1:
- raise ParseError(
- 'Failed to parse timestamp: missing valid timezone offset.')
- time_value = value[0:timezone_offset]
- # Parse datetime and nanos
- point_position = time_value.find('.')
- if point_position == -1:
- second_value = time_value
- nano_value = ''
- else:
- second_value = time_value[:point_position]
- nano_value = time_value[point_position + 1:]
- date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT)
- td = date_object - datetime(1970, 1, 1)
- seconds = td.seconds + td.days * 24 * 3600
- if len(nano_value) > 9:
- raise ParseError(
- 'Failed to parse Timestamp: nanos {0} more than '
- '9 fractional digits.'.format(nano_value))
- if nano_value:
- nanos = round(float('0.' + nano_value) * 1e9)
- else:
- nanos = 0
- # Parse timezone offsets
- if value[timezone_offset] == 'Z':
- if len(value) != timezone_offset + 1:
- raise ParseError(
- 'Failed to parse timestamp: invalid trailing data {0}.'.format(value))
- else:
- timezone = value[timezone_offset:]
- pos = timezone.find(':')
- if pos == -1:
- raise ParseError(
- 'Invalid timezone offset value: ' + timezone)
- if timezone[0] == '+':
- seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
- else:
- seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
- # Set seconds and nanos
- message.seconds = int(seconds)
- message.nanos = int(nanos)
-
-
-def _ConvertDurationMessage(value, message):
- """Convert a JSON representation into Duration message."""
- if value[-1] != 's':
- raise ParseError(
- 'Duration must end with letter "s": ' + value)
- try:
- duration = float(value[:-1])
- except ValueError:
- raise ParseError(
- 'Couldn\'t parse duration: ' + value)
- message.seconds = int(duration)
- message.nanos = int(round((duration - message.seconds) * 1e9))
-
-
-def _ConvertFieldMaskMessage(value, message):
- """Convert a JSON representation into FieldMask message."""
- for path in value.split(','):
- message.paths.append(path)
-
-
def _ConvertWrapperMessage(value, message):
"""Convert a JSON representation into Wrapper message."""
field = message.DESCRIPTOR.fields_by_name['value']
@@ -512,13 +365,13 @@ def _ConvertMapFieldValue(value, message, field):
value[key], value_field)
-def _ConvertScalarFieldValue(value, field, require_quote=False):
+def _ConvertScalarFieldValue(value, field, require_str=False):
"""Convert a single scalar field value.
Args:
value: The scalar value to convert.
field: The descriptor of the field to convert.
- require_quote: If True, '"' is required for the field value.
+ require_str: If True, the field value must be a str.
Returns:
The converted scalar field value
@@ -531,7 +384,7 @@ def _ConvertScalarFieldValue(value, field, require_quote=False):
elif field.cpp_type in _FLOAT_TYPES:
return _ConvertFloat(value)
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
- return _ConvertBool(value, require_quote)
+ return _ConvertBool(value, require_str)
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
return base64.b64decode(value)
@@ -561,10 +414,10 @@ def _ConvertInteger(value):
ParseError: If an integer couldn't be consumed.
"""
if isinstance(value, float):
- raise ParseError('Couldn\'t parse integer: {0}'.format(value))
+ raise ParseError('Couldn\'t parse integer: {0}.'.format(value))
- if isinstance(value, _UNICODETYPE) and not _INTEGER.match(value):
- raise ParseError('Couldn\'t parse integer: "{0}"'.format(value))
+ if isinstance(value, text_type) and value.find(' ') != -1:
+ raise ParseError('Couldn\'t parse integer: "{0}".'.format(value))
return int(value)
@@ -572,28 +425,28 @@ def _ConvertInteger(value):
def _ConvertFloat(value):
"""Convert an floating point number."""
if value == 'nan':
- raise ParseError('Couldn\'t parse float "nan", use "NaN" instead')
+ raise ParseError('Couldn\'t parse float "nan", use "NaN" instead.')
try:
# Assume Python compatible syntax.
return float(value)
except ValueError:
# Check alternative spellings.
- if value == '-Infinity':
+ if value == _NEG_INFINITY:
return float('-inf')
- elif value == 'Infinity':
+ elif value == _INFINITY:
return float('inf')
- elif value == 'NaN':
+ elif value == _NAN:
return float('nan')
else:
- raise ParseError('Couldn\'t parse float: {0}'.format(value))
+ raise ParseError('Couldn\'t parse float: {0}.'.format(value))
-def _ConvertBool(value, require_quote):
+def _ConvertBool(value, require_str):
"""Convert a boolean value.
Args:
value: A scalar value to convert.
- require_quote: If True, '"' is required for the boolean value.
+ require_str: If True, value must be a str.
Returns:
The bool parsed.
@@ -601,13 +454,13 @@ def _ConvertBool(value, require_quote):
Raises:
ParseError: If a boolean value couldn't be consumed.
"""
- if require_quote:
+ if require_str:
if value == 'true':
return True
elif value == 'false':
return False
else:
- raise ParseError('Expect "true" or "false", not {0}.'.format(value))
+ raise ParseError('Expected "true" or "false", not {0}.'.format(value))
if not isinstance(value, bool):
raise ParseError('Expected true or false without quotes.')
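The net effect of the json_format changes is that well-known types are no longer special-cased by name: serialization asks the message for ToJsonString(), and nested parsing dispatches on FromJsonString(). A minimal sketch, assuming the standard timestamp_pb2 module:

    from google.protobuf import json_format
    from google.protobuf import timestamp_pb2

    ts = timestamp_pb2.Timestamp(seconds=0, nanos=10000000)
    print(json_format.MessageToJson(ts))  # "1970-01-01T00:00:00.010Z"

    parsed = timestamp_pb2.Timestamp()
    parsed.FromJsonString('1970-01-01T00:00:00.010Z')
    assert parsed == ts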
diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py
index 9cd9c2a8..1b059d13 100644
--- a/python/google/protobuf/message_factory.py
+++ b/python/google/protobuf/message_factory.py
@@ -39,7 +39,6 @@ my_proto_instance = message_classes['some.proto.package.MessageName']()
__author__ = 'matthewtoia@google.com (Matt Toia)'
-from google.protobuf import descriptor_database
from google.protobuf import descriptor_pool
from google.protobuf import message
from google.protobuf import reflection
@@ -50,8 +49,7 @@ class MessageFactory(object):
def __init__(self, pool=None):
"""Initializes a new factory."""
- self.pool = (pool or descriptor_pool.DescriptorPool(
- descriptor_database.DescriptorDatabase()))
+ self.pool = pool or descriptor_pool.DescriptorPool()
# local cache of all classes built from protobuf descriptors
self._classes = {}
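With this change a bare MessageFactory() no longer carries a DescriptorDatabase behind its pool; callers who want database-backed lookup pass their own pool. A sketch of both configurations:

    from google.protobuf import descriptor_database
    from google.protobuf import descriptor_pool
    from google.protobuf import message_factory

    # Default: a plain DescriptorPool with no backing database.
    factory = message_factory.MessageFactory()

    # Opt back in to database-backed lookup explicitly.
    db = descriptor_database.DescriptorDatabase()
    factory_with_db = message_factory.MessageFactory(
        descriptor_pool.DescriptorPool(db))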
diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc
index 61a3d237..a875a7be 100644
--- a/python/google/protobuf/pyext/descriptor.cc
+++ b/python/google/protobuf/pyext/descriptor.cc
@@ -203,6 +203,14 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
PyObject* message_class(cdescriptor_pool::GetMessageClass(
pool, message_type));
if (message_class == NULL) {
+ // The Options message was not found in the current DescriptorPool.
+ // In this case, there cannot be extensions to these options, and we can
+ // try to use the basic pool instead.
+ PyErr_Clear();
+ message_class = cdescriptor_pool::GetMessageClass(
+ GetDefaultDescriptorPool(), message_type);
+ }
+ if (message_class == NULL) {
PyErr_Format(PyExc_TypeError, "Could not retrieve class for Options: %s",
message_type->full_name().c_str());
return NULL;
@@ -211,6 +219,12 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
if (value == NULL) {
return NULL;
}
+ if (!PyObject_TypeCheck(value.get(), &CMessage_Type)) {
+ PyErr_Format(PyExc_TypeError, "Invalid class for %s: %s",
+ message_type->full_name().c_str(),
+ Py_TYPE(value.get())->tp_name);
+ return NULL;
+ }
CMessage* cmsg = reinterpret_cast<CMessage*>(value.get());
const Reflection* reflection = options.GetReflection();
@@ -327,7 +341,8 @@ PyObject* NewInternedDescriptor(PyTypeObject* type,
PyDescriptorPool* pool = GetDescriptorPool_FromPool(
GetFileDescriptor(descriptor)->pool());
if (pool == NULL) {
- Py_DECREF(py_descriptor);
+ // Don't DECREF, the object is not fully initialized.
+ PyObject_Del(py_descriptor);
return NULL;
}
Py_INCREF(pool);
@@ -1213,6 +1228,13 @@ static void Dealloc(PyFileDescriptor* self) {
descriptor::Dealloc(&self->base);
}
+static PyObject* GetPool(PyFileDescriptor *self, void *closure) {
+ PyObject* pool = reinterpret_cast<PyObject*>(
+ GetDescriptorPool_FromPool(_GetDescriptor(self)->pool()));
+ Py_XINCREF(pool);
+ return pool;
+}
+
static PyObject* GetName(PyFileDescriptor *self, void *closure) {
return PyString_FromCppString(_GetDescriptor(self)->name());
}
@@ -1292,6 +1314,7 @@ static PyObject* CopyToProto(PyFileDescriptor *self, PyObject *target) {
}
static PyGetSetDef Getters[] = {
+ { "pool", (getter)GetPool, NULL, "pool"},
{ "name", (getter)GetName, NULL, "name"},
{ "package", (getter)GetPackage, NULL, "package"},
{ "serialized_pb", (getter)GetSerializedPb},
@@ -1354,8 +1377,8 @@ PyTypeObject PyFileDescriptor_Type = {
0, // tp_descr_set
0, // tp_dictoffset
0, // tp_init
- PyType_GenericAlloc, // tp_alloc
- PyType_GenericNew, // tp_new
+ 0, // tp_alloc
+ 0, // tp_new
PyObject_Del, // tp_free
};
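
The new getter matches the pure-Python FileDescriptor.pool attribute, so the owning pool is now reachable from either implementation. A small sketch:

    from google.protobuf import descriptor_pb2

    file_descriptor = descriptor_pb2.FileDescriptorProto.DESCRIPTOR.file
    print(file_descriptor.name)  # 'google/protobuf/descriptor.proto'
    print(file_descriptor.pool)  # the DescriptorPool that owns this file
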
diff --git a/python/google/protobuf/pyext/descriptor_database.cc b/python/google/protobuf/pyext/descriptor_database.cc
new file mode 100644
index 00000000..514722b4
--- /dev/null
+++ b/python/google/protobuf/pyext/descriptor_database.cc
@@ -0,0 +1,145 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// This file defines a C++ DescriptorDatabase, which wraps a Python Database
+// and delegates all its operations to Python methods.
+
+#include <google/protobuf/pyext/descriptor_database.h>
+
+#include <google/protobuf/stubs/logging.h>
+#include <google/protobuf/stubs/common.h>
+#include <google/protobuf/descriptor.pb.h>
+#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+PyDescriptorDatabase::PyDescriptorDatabase(PyObject* py_database)
+ : py_database_(py_database) {
+ Py_INCREF(py_database_);
+}
+
+PyDescriptorDatabase::~PyDescriptorDatabase() { Py_DECREF(py_database_); }
+
+// Convert a Python object to a FileDescriptorProto pointer.
+// Handles all kinds of Python errors, which are simply logged.
+static bool GetFileDescriptorProto(PyObject* py_descriptor,
+ FileDescriptorProto* output) {
+ if (py_descriptor == NULL) {
+ if (PyErr_ExceptionMatches(PyExc_KeyError)) {
+ // Expected error: item was simply not found.
+ PyErr_Clear();
+ } else {
+ GOOGLE_LOG(ERROR) << "DescriptorDatabase method raised an error";
+ PyErr_Print();
+ }
+ return false;
+ }
+ const Descriptor* filedescriptor_descriptor =
+ FileDescriptorProto::default_instance().GetDescriptor();
+ CMessage* message = reinterpret_cast<CMessage*>(py_descriptor);
+ if (PyObject_TypeCheck(py_descriptor, &CMessage_Type) &&
+ message->message->GetDescriptor() == filedescriptor_descriptor) {
+ // Fast path: Just use the pointer.
+ FileDescriptorProto* file_proto =
+ static_cast<FileDescriptorProto*>(message->message);
+ *output = *file_proto;
+ return true;
+ } else {
+ // Slow path: serialize the message. This supports databases that
+ // use a different implementation of FileDescriptorProto.
+ ScopedPyObjectPtr serialized_pb(
+ PyObject_CallMethod(py_descriptor, "SerializeToString", NULL));
+ if (serialized_pb == NULL) {
+ GOOGLE_LOG(ERROR)
+ << "DescriptorDatabase method did not return a FileDescriptorProto";
+ PyErr_Print();
+ return false;
+ }
+ char* str;
+ Py_ssize_t len;
+ if (PyBytes_AsStringAndSize(serialized_pb.get(), &str, &len) < 0) {
+ GOOGLE_LOG(ERROR)
+ << "DescriptorDatabase method did not return a FileDescriptorProto";
+ PyErr_Print();
+ return false;
+ }
+ FileDescriptorProto file_proto;
+ if (!file_proto.ParseFromArray(str, len)) {
+ GOOGLE_LOG(ERROR)
+ << "DescriptorDatabase method did not return a FileDescriptorProto";
+ return false;
+ }
+ *output = file_proto;
+ return true;
+ }
+}
+
+// Find a file by file name.
+bool PyDescriptorDatabase::FindFileByName(const string& filename,
+ FileDescriptorProto* output) {
+ ScopedPyObjectPtr py_descriptor(PyObject_CallMethod(
+ py_database_, "FindFileByName", "s#", filename.c_str(), filename.size()));
+ return GetFileDescriptorProto(py_descriptor.get(), output);
+}
+
+// Find the file that declares the given fully-qualified symbol name.
+bool PyDescriptorDatabase::FindFileContainingSymbol(
+ const string& symbol_name, FileDescriptorProto* output) {
+ ScopedPyObjectPtr py_descriptor(
+ PyObject_CallMethod(py_database_, "FindFileContainingSymbol", "s#",
+ symbol_name.c_str(), symbol_name.size()));
+ return GetFileDescriptorProto(py_descriptor.get(), output);
+}
+
+// Find the file which defines an extension extending the given message type
+// with the given field number.
+// Python DescriptorDatabases are not required to implement this method.
+bool PyDescriptorDatabase::FindFileContainingExtension(
+ const string& containing_type, int field_number,
+ FileDescriptorProto* output) {
+ ScopedPyObjectPtr py_method(
+ PyObject_GetAttrString(py_database_, "FindFileContainingExtension"));
+ if (py_method == NULL) {
+ // This method is not implemented; return without error.
+ PyErr_Clear();
+ return false;
+ }
+ ScopedPyObjectPtr py_descriptor(
+ PyObject_CallFunction(py_method.get(), "s#i", containing_type.c_str(),
+ containing_type.size(), field_number));
+ return GetFileDescriptorProto(py_descriptor.get(), output);
+}
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
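
For reference, a minimal Python-side database this wrapper could delegate to; raising KeyError is the expected "not found" signal (cleared in GetFileDescriptorProto above), and FindFileContainingExtension may be omitted entirely. The class name and storage scheme are illustrative only:

    class DictDescriptorDatabase(object):
      """Illustrative database storing FileDescriptorProtos by file name."""

      def __init__(self):
        self._files = {}

      def Add(self, file_desc_proto):
        self._files[file_desc_proto.name] = file_desc_proto

      def FindFileByName(self, name):
        # A KeyError here is caught and cleared by the C++ wrapper.
        return self._files[name]

      def FindFileContainingSymbol(self, symbol):
        for file_proto in self._files.values():
          for message_type in file_proto.message_type:
            full_name = '.'.join(filter(None, [file_proto.package,
                                               message_type.name]))
            if symbol == full_name:
              return file_proto
        raise KeyError(symbol)
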
diff --git a/python/google/protobuf/pyext/descriptor_database.h b/python/google/protobuf/pyext/descriptor_database.h
new file mode 100644
index 00000000..fc71c4bc
--- /dev/null
+++ b/python/google/protobuf/pyext/descriptor_database.h
@@ -0,0 +1,75 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__
+
+#include <Python.h>
+
+#include <google/protobuf/descriptor_database.h>
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+class PyDescriptorDatabase : public DescriptorDatabase {
+ public:
+ explicit PyDescriptorDatabase(PyObject* py_database);
+ ~PyDescriptorDatabase();
+
+ // Implement the abstract interface. All these functions fill the output
+ // with a copy of the FileDescriptorProto.
+
+ // Find a file by file name.
+ bool FindFileByName(const string& filename,
+ FileDescriptorProto* output);
+
+ // Find the file that declares the given fully-qualified symbol name.
+ bool FindFileContainingSymbol(const string& symbol_name,
+ FileDescriptorProto* output);
+
+ // Find the file which defines an extension extending the given message type
+ // with the given field number.
+ // containing_type must be a fully-qualified type name.
+ // Python databases are not required to implement this method.
+ bool FindFileContainingExtension(const string& containing_type,
+ int field_number,
+ FileDescriptorProto* output);
+
+ private:
+ // The Python object that implements the database. The reference is owned.
+ PyObject* py_database_;
+};
+
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__
diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc
index 0f7487fa..0bc76bc9 100644
--- a/python/google/protobuf/pyext/descriptor_pool.cc
+++ b/python/google/protobuf/pyext/descriptor_pool.cc
@@ -34,8 +34,9 @@
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/dynamic_message.h>
-#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/descriptor_database.h>
+#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
@@ -60,38 +61,93 @@ static hash_map<const DescriptorPool*, PyDescriptorPool*> descriptor_pool_map;
namespace cdescriptor_pool {
-static PyDescriptorPool* NewDescriptorPool() {
- PyDescriptorPool* cdescriptor_pool = PyObject_New(
+// Creates a Python DescriptorPool object, but does not fill the "pool"
+// attribute.
+static PyDescriptorPool* _CreateDescriptorPool() {
+ PyDescriptorPool* cpool = PyObject_New(
PyDescriptorPool, &PyDescriptorPool_Type);
- if (cdescriptor_pool == NULL) {
+ if (cpool == NULL) {
return NULL;
}
- // Build a DescriptorPool for messages only declared in Python libraries.
- // generated_pool() contains all messages linked in C++ libraries, and is used
- // as underlay.
- cdescriptor_pool->pool = new DescriptorPool(DescriptorPool::generated_pool());
+ cpool->underlay = NULL;
+ cpool->database = NULL;
DynamicMessageFactory* message_factory = new DynamicMessageFactory();
// This option might be the default some day.
message_factory->SetDelegateToGeneratedFactory(true);
- cdescriptor_pool->message_factory = message_factory;
+ cpool->message_factory = message_factory;
// TODO(amauryfa): Rewrite the SymbolDatabase in C so that it uses the same
// storage.
- cdescriptor_pool->classes_by_descriptor =
+ cpool->classes_by_descriptor =
new PyDescriptorPool::ClassesByMessageMap();
- cdescriptor_pool->descriptor_options =
+ cpool->descriptor_options =
new hash_map<const void*, PyObject *>();
+ return cpool;
+}
+
+// Creates a Python DescriptorPool that uses the given pool as an underlay:
+// new messages will be added to a custom pool, not to the underlay.
+//
+// Ownership of the underlay is not transferred; the pointed-to pool must
+// outlive the returned object.
+static PyDescriptorPool* PyDescriptorPool_NewWithUnderlay(
+ const DescriptorPool* underlay) {
+ PyDescriptorPool* cpool = _CreateDescriptorPool();
+ if (cpool == NULL) {
+ return NULL;
+ }
+ cpool->pool = new DescriptorPool(underlay);
+ cpool->underlay = underlay;
+
if (!descriptor_pool_map.insert(
- std::make_pair(cdescriptor_pool->pool, cdescriptor_pool)).second) {
+ std::make_pair(cpool->pool, cpool)).second) {
// Should never happen -- would indicate an internal error / bug.
PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered");
return NULL;
}
- return cdescriptor_pool;
+ return cpool;
+}
+
+static PyDescriptorPool* PyDescriptorPool_NewWithDatabase(
+ DescriptorDatabase* database) {
+ PyDescriptorPool* cpool = _CreateDescriptorPool();
+ if (cpool == NULL) {
+ return NULL;
+ }
+ if (database != NULL) {
+ cpool->pool = new DescriptorPool(database);
+ cpool->database = database;
+ } else {
+ cpool->pool = new DescriptorPool();
+ }
+
+ if (!descriptor_pool_map.insert(std::make_pair(cpool->pool, cpool)).second) {
+ // Should never happen -- would indicate an internal error / bug.
+ PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered");
+ return NULL;
+ }
+
+ return cpool;
+}
+
+// The public DescriptorPool constructor.
+static PyObject* New(PyTypeObject* type,
+ PyObject* args, PyObject* kwargs) {
+ static char* kwlist[] = {"descriptor_db", 0};
+ PyObject* py_database = NULL;
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, &py_database)) {
+ return NULL;
+ }
+ DescriptorDatabase* database = NULL;
+ if (py_database && py_database != Py_None) {
+ database = new PyDescriptorDatabase(py_database);
+ }
+ return reinterpret_cast<PyObject*>(
+ PyDescriptorPool_NewWithDatabase(database));
}
static void Dealloc(PyDescriptorPool* self) {
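
Exposing tp_new makes the C++ pool constructible from Python, with the same optional descriptor_db argument the pure-Python class accepts:

    from google.protobuf import descriptor_database
    from google.protobuf import descriptor_pool

    db = descriptor_database.DescriptorDatabase()
    pool = descriptor_pool.DescriptorPool(descriptor_db=db)
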
@@ -108,8 +164,9 @@ static void Dealloc(PyDescriptorPool* self) {
Py_DECREF(it->second);
}
delete self->descriptor_options;
- delete self->pool;
delete self->message_factory;
+ delete self->database;
+ delete self->pool;
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@@ -354,6 +411,14 @@ PyObject* AddSerializedFile(PyDescriptorPool* self, PyObject* serialized_pb) {
char* message_type;
Py_ssize_t message_len;
+ if (self->database != NULL) {
+ PyErr_SetString(
+ PyExc_ValueError,
+ "Cannot call Add on a DescriptorPool that uses a DescriptorDatabase. "
+ "Add your file to the underlying database.");
+ return NULL;
+ }
+
if (PyBytes_AsStringAndSize(serialized_pb, &message_type, &message_len) < 0) {
return NULL;
}
@@ -366,8 +431,10 @@ PyObject* AddSerializedFile(PyDescriptorPool* self, PyObject* serialized_pb) {
// If the file was already part of a C++ library, all its descriptors are in
// the underlying pool. No need to do anything else.
- const FileDescriptor* generated_file =
- DescriptorPool::generated_pool()->FindFileByName(file_proto.name());
+ const FileDescriptor* generated_file = NULL;
+ if (self->underlay) {
+ generated_file = self->underlay->FindFileByName(file_proto.name());
+ }
if (generated_file != NULL) {
return PyFileDescriptor_FromDescriptorWithSerializedPb(
generated_file, serialized_pb);
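
The guard above makes database-backed pools read-through only: files must be registered with the database, not the pool. A sketch, where file_desc_proto stands for some FileDescriptorProto:

    try:
      pool.AddSerializedFile(file_desc_proto.SerializeToString())
    except ValueError:
      # Database-backed pools reject Add*; feed the database instead.
      db.Add(file_desc_proto)
    found = pool.FindFileByName(file_desc_proto.name)
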
@@ -470,7 +537,7 @@ PyTypeObject PyDescriptorPool_Type = {
0, // tp_dictoffset
0, // tp_init
0, // tp_alloc
- 0, // tp_new
+ cdescriptor_pool::New, // tp_new
PyObject_Del, // tp_free
};
@@ -482,7 +549,11 @@ bool InitDescriptorPool() {
if (PyType_Ready(&PyDescriptorPool_Type) < 0)
return false;
- python_generated_pool = cdescriptor_pool::NewDescriptorPool();
+ // The pool of messages declared in Python libraries.
+ // generated_pool() contains all messages already linked into C++ libraries
+ // and is used as the underlay.
+ python_generated_pool = cdescriptor_pool::PyDescriptorPool_NewWithUnderlay(
+ DescriptorPool::generated_pool());
if (python_generated_pool == NULL) {
return false;
}
@@ -494,6 +565,10 @@ bool InitDescriptorPool() {
return true;
}
+// The default DescriptorPool used everywhere in this module.
+// Today it's the python_generated_pool.
+// TODO(amauryfa): Remove all usages of this function: the pool should be
+// derived from the context.
PyDescriptorPool* GetDefaultDescriptorPool() {
return python_generated_pool;
}
diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h
index eda73d38..16bc910c 100644
--- a/python/google/protobuf/pyext/descriptor_pool.h
+++ b/python/google/protobuf/pyext/descriptor_pool.h
@@ -55,8 +55,17 @@ namespace python {
typedef struct PyDescriptorPool {
PyObject_HEAD
+ // The C++ pool containing Descriptors.
DescriptorPool* pool;
+ // The C++ pool acting as an underlay. Can be NULL.
+ // This pointer is not owned; the pointed-to pool must outlive this object.
+ const DescriptorPool* underlay;
+
+ // The C++ descriptor database used to fetch unknown protos. Can be NULL.
+ // This pointer is owned.
+ const DescriptorDatabase* database;
+
// DynamicMessageFactory used to create C++ instances of messages.
  // This object caches the descriptors that were used, so the DescriptorPool
// needs to get rid of it before it can delete itself.
@@ -138,6 +147,7 @@ PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg);
// Retrieve the global descriptor pool owned by the _message module.
// This is the one used by pb2.py generated modules.
// Returns a *borrowed* reference.
+// "Default" pool used to register messages from _pb2.py modules.
PyDescriptorPool* GetDefaultDescriptorPool();
// Retrieve the python descriptor pool owning a C++ descriptor pool.
diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc
index b361b342..555bd293 100644
--- a/python/google/protobuf/pyext/extension_dict.cc
+++ b/python/google/protobuf/pyext/extension_dict.cc
@@ -94,13 +94,13 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
if (descriptor == NULL) {
return NULL;
}
- if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
+ if (!CheckFieldBelongsToMessage(descriptor, self->message)) {
return NULL;
}
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
- return cmessage::InternalGetScalar(self->parent->message, descriptor);
+ return cmessage::InternalGetScalar(self->message, descriptor);
}
PyObject* value = PyDict_GetItem(self->values, key);
@@ -109,6 +109,14 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
return value;
}
+ if (self->parent == NULL) {
+ // We are in "detached" state. Don't allow further modifications.
+ // TODO(amauryfa): Support adding non-scalars to a detached extension dict.
+ // This probably requires storing the type of the main message.
+ PyErr_SetObject(PyExc_KeyError, key);
+ return NULL;
+ }
+
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
PyObject* sub_message = cmessage::InternalGetSubMessage(
@@ -154,7 +162,7 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
if (descriptor == NULL) {
return -1;
}
- if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
+ if (!CheckFieldBelongsToMessage(descriptor, self->message)) {
return -1;
}
@@ -164,9 +172,11 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
"type");
return -1;
}
- cmessage::AssureWritable(self->parent);
- if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
- return -1;
+ if (self->parent) {
+ cmessage::AssureWritable(self->parent);
+ if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
+ return -1;
+ }
}
// TODO(tibell): We shouldn't write scalars to the cache.
PyDict_SetItem(self->values, key, value);
@@ -180,15 +190,17 @@ PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) {
return NULL;
}
PyObject* value = PyDict_GetItem(self->values, extension);
- if (value != NULL) {
- if (ReleaseExtension(self, value, descriptor) < 0) {
+ if (self->parent) {
+ if (value != NULL) {
+ if (ReleaseExtension(self, value, descriptor) < 0) {
+ return NULL;
+ }
+ }
+ if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor(
+ self->parent, descriptor)) == NULL) {
return NULL;
}
}
- if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor(
- self->parent, descriptor)) == NULL) {
- return NULL;
- }
if (PyDict_DelItem(self->values, extension) < 0) {
PyErr_Clear();
}
@@ -201,8 +213,15 @@ PyObject* HasExtension(ExtensionDict* self, PyObject* extension) {
if (descriptor == NULL) {
return NULL;
}
- PyObject* result = cmessage::HasFieldByDescriptor(self->parent, descriptor);
- return result;
+ if (self->parent) {
+ return cmessage::HasFieldByDescriptor(self->parent, descriptor);
+ } else {
+ int exists = PyDict_Contains(self->values, extension);
+ if (exists < 0) {
+ return NULL;
+ }
+ return PyBool_FromLong(exists);
+ }
}
PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) {
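
These code paths back the Extensions mapping on messages built with the C++ implementation. A hedged sketch of the user-visible behavior; the generated module, message, and extension names are hypothetical:

    from my_extensions_pb2 import BaseMessage, my_int_ext  # hypothetical

    msg = BaseMessage()
    print(msg.HasExtension(my_int_ext))  # False while unset
    msg.Extensions[my_int_ext] = 42      # ass_subscript -> InternalSetScalar
    print(msg.Extensions[my_int_ext])    # 42, via the subscript fast path
    msg.ClearExtension(my_int_ext)       # also drops the cached value
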
diff --git a/python/google/protobuf/pyext/extension_dict.h b/python/google/protobuf/pyext/extension_dict.h
index 7a66cb23..d92cf956 100644
--- a/python/google/protobuf/pyext/extension_dict.h
+++ b/python/google/protobuf/pyext/extension_dict.h
@@ -117,11 +117,6 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value);
PyObject* ClearExtension(ExtensionDict* self,
PyObject* extension);
-// Checks if the dict has an extension.
-//
-// Returns a new python boolean reference.
-PyObject* HasExtension(ExtensionDict* self, PyObject* extension);
-
// Gets an extension from the dict given the extension name as opposed to
// descriptor.
//
diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc
new file mode 100644
index 00000000..c39f7b83
--- /dev/null
+++ b/python/google/protobuf/pyext/map_container.cc
@@ -0,0 +1,912 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: haberman@google.com (Josh Haberman)
+
+#include <google/protobuf/pyext/map_container.h>
+
+#include <google/protobuf/stubs/logging.h>
+#include <google/protobuf/stubs/common.h>
+#include <google/protobuf/stubs/scoped_ptr.h>
+#include <google/protobuf/map_field.h>
+#include <google/protobuf/map.h>
+#include <google/protobuf/message.h>
+#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+
+#if PY_MAJOR_VERSION >= 3
+ #define PyInt_FromLong PyLong_FromLong
+ #define PyInt_FromSize_t PyLong_FromSize_t
+#endif
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+// Functions that need access to map reflection functionality.
+// They need to be contained in this class because it is friended.
+class MapReflectionFriend {
+ public:
+ // Methods that are in common between the map types.
+ static PyObject* Contains(PyObject* _self, PyObject* key);
+ static Py_ssize_t Length(PyObject* _self);
+ static PyObject* GetIterator(PyObject *_self);
+ static PyObject* IterNext(PyObject* _self);
+
+ // Methods that differ between the map types.
+ static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key);
+ static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key);
+ static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
+ static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
+};
+
+struct MapIterator {
+ PyObject_HEAD;
+
+ scoped_ptr< ::google::protobuf::MapIterator> iter;
+
+ // A pointer back to the container, so we can notice changes to the version.
+ // We own a ref on this.
+ MapContainer* container;
+
+ // We need to keep a ref on the Message* too, because
+ // MapIterator::~MapIterator() accesses it. Normally this would be ok because
+ // the ref on container (above) would guarantee outlive semantics. However in
+ // the case of ClearField(), InitializeAndCopyToParentContainer() resets the
+ // message pointer (and the owner) to a different message, a copy of the
+ // original. But our iterator still points to the original, which could now
+ // get deleted before us.
+ //
+ // To prevent this, we ensure that the Message will always stay alive as long
+ // as this iterator does. This is solely for the benefit of the MapIterator
+ // destructor -- we should never actually access the iterator in this state
+ // except to delete it.
+ shared_ptr<Message> owner;
+
+ // The version of the map when we took the iterator to it.
+ //
+ // We store this so that if the map is modified during iteration we can throw
+ // an error.
+ uint64 version;
+
+ // True if the container is empty. We signal this separately to avoid calling
+ // any of the iteration methods, which are non-const.
+ bool empty;
+};
+
+Message* MapContainer::GetMutableMessage() {
+ cmessage::AssureWritable(parent);
+ return const_cast<Message*>(message);
+}
+
+// Consumes a reference on the Python string object.
+static bool PyStringToSTL(PyObject* py_string, string* stl_string) {
+ char *value;
+ Py_ssize_t value_len;
+
+ if (!py_string) {
+ return false;
+ }
+ if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
+ Py_DECREF(py_string);
+ return false;
+ } else {
+ stl_string->assign(value, value_len);
+ Py_DECREF(py_string);
+ return true;
+ }
+}
+
+static bool PythonToMapKey(PyObject* obj,
+ const FieldDescriptor* field_descriptor,
+ MapKey* key) {
+ switch (field_descriptor->cpp_type()) {
+ case FieldDescriptor::CPPTYPE_INT32: {
+ GOOGLE_CHECK_GET_INT32(obj, value, false);
+ key->SetInt32Value(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_INT64: {
+ GOOGLE_CHECK_GET_INT64(obj, value, false);
+ key->SetInt64Value(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_UINT32: {
+ GOOGLE_CHECK_GET_UINT32(obj, value, false);
+ key->SetUInt32Value(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_UINT64: {
+ GOOGLE_CHECK_GET_UINT64(obj, value, false);
+ key->SetUInt64Value(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_BOOL: {
+ GOOGLE_CHECK_GET_BOOL(obj, value, false);
+ key->SetBoolValue(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_STRING: {
+ string str;
+ if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
+ return false;
+ }
+ key->SetStringValue(str);
+ break;
+ }
+ default:
+ PyErr_Format(
+ PyExc_SystemError, "Type %d cannot be a map key",
+ field_descriptor->cpp_type());
+ return false;
+ }
+ return true;
+}
+
+static PyObject* MapKeyToPython(const FieldDescriptor* field_descriptor,
+ const MapKey& key) {
+ switch (field_descriptor->cpp_type()) {
+ case FieldDescriptor::CPPTYPE_INT32:
+ return PyInt_FromLong(key.GetInt32Value());
+ case FieldDescriptor::CPPTYPE_INT64:
+ return PyLong_FromLongLong(key.GetInt64Value());
+ case FieldDescriptor::CPPTYPE_UINT32:
+ return PyInt_FromSize_t(key.GetUInt32Value());
+ case FieldDescriptor::CPPTYPE_UINT64:
+ return PyLong_FromUnsignedLongLong(key.GetUInt64Value());
+ case FieldDescriptor::CPPTYPE_BOOL:
+ return PyBool_FromLong(key.GetBoolValue());
+ case FieldDescriptor::CPPTYPE_STRING:
+ return ToStringObject(field_descriptor, key.GetStringValue());
+ default:
+ PyErr_Format(
+ PyExc_SystemError, "Couldn't convert type %d to value",
+ field_descriptor->cpp_type());
+ return NULL;
+ }
+}
+
+// This is only used for ScalarMap, so we don't need to handle the
+// CPPTYPE_MESSAGE case.
+PyObject* MapValueRefToPython(const FieldDescriptor* field_descriptor,
+ MapValueRef* value) {
+ switch (field_descriptor->cpp_type()) {
+ case FieldDescriptor::CPPTYPE_INT32:
+ return PyInt_FromLong(value->GetInt32Value());
+ case FieldDescriptor::CPPTYPE_INT64:
+ return PyLong_FromLongLong(value->GetInt64Value());
+ case FieldDescriptor::CPPTYPE_UINT32:
+ return PyInt_FromSize_t(value->GetUInt32Value());
+ case FieldDescriptor::CPPTYPE_UINT64:
+ return PyLong_FromUnsignedLongLong(value->GetUInt64Value());
+ case FieldDescriptor::CPPTYPE_FLOAT:
+ return PyFloat_FromDouble(value->GetFloatValue());
+ case FieldDescriptor::CPPTYPE_DOUBLE:
+ return PyFloat_FromDouble(value->GetDoubleValue());
+ case FieldDescriptor::CPPTYPE_BOOL:
+ return PyBool_FromLong(value->GetBoolValue());
+ case FieldDescriptor::CPPTYPE_STRING:
+ return ToStringObject(field_descriptor, value->GetStringValue());
+ case FieldDescriptor::CPPTYPE_ENUM:
+ return PyInt_FromLong(value->GetEnumValue());
+ default:
+ PyErr_Format(
+ PyExc_SystemError, "Couldn't convert type %d to value",
+ field_descriptor->cpp_type());
+ return NULL;
+ }
+}
+
+// This is only used for ScalarMap, so we don't need to handle the
+// CPPTYPE_MESSAGE case.
+static bool PythonToMapValueRef(PyObject* obj,
+ const FieldDescriptor* field_descriptor,
+ bool allow_unknown_enum_values,
+ MapValueRef* value_ref) {
+ switch (field_descriptor->cpp_type()) {
+ case FieldDescriptor::CPPTYPE_INT32: {
+ GOOGLE_CHECK_GET_INT32(obj, value, false);
+ value_ref->SetInt32Value(value);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_INT64: {
+ GOOGLE_CHECK_GET_INT64(obj, value, false);
+ value_ref->SetInt64Value(value);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_UINT32: {
+ GOOGLE_CHECK_GET_UINT32(obj, value, false);
+ value_ref->SetUInt32Value(value);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_UINT64: {
+ GOOGLE_CHECK_GET_UINT64(obj, value, false);
+ value_ref->SetUInt64Value(value);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_FLOAT: {
+ GOOGLE_CHECK_GET_FLOAT(obj, value, false);
+ value_ref->SetFloatValue(value);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_DOUBLE: {
+ GOOGLE_CHECK_GET_DOUBLE(obj, value, false);
+ value_ref->SetDoubleValue(value);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_BOOL: {
+ GOOGLE_CHECK_GET_BOOL(obj, value, false);
+ value_ref->SetBoolValue(value);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_STRING: {
+ string str;
+ if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
+ return false;
+ }
+ value_ref->SetStringValue(str);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_ENUM: {
+ GOOGLE_CHECK_GET_INT32(obj, value, false);
+ if (allow_unknown_enum_values) {
+ value_ref->SetEnumValue(value);
+ return true;
+ } else {
+ const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
+ const EnumValueDescriptor* enum_value =
+ enum_descriptor->FindValueByNumber(value);
+ if (enum_value != NULL) {
+ value_ref->SetEnumValue(value);
+ return true;
+ } else {
+ PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
+ return false;
+ }
+ }
+ break;
+ }
+ default:
+ PyErr_Format(
+ PyExc_SystemError, "Setting value to a field of unknown type %d",
+ field_descriptor->cpp_type());
+ return false;
+ }
+}
+
+// Map methods common to ScalarMap and MessageMap //////////////////////////////
+
+static MapContainer* GetMap(PyObject* obj) {
+ return reinterpret_cast<MapContainer*>(obj);
+}
+
+Py_ssize_t MapReflectionFriend::Length(PyObject* _self) {
+ MapContainer* self = GetMap(_self);
+ const google::protobuf::Message* message = self->message;
+ return message->GetReflection()->MapSize(*message,
+ self->parent_field_descriptor);
+}
+
+PyObject* Clear(PyObject* _self) {
+ MapContainer* self = GetMap(_self);
+ Message* message = self->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+
+ reflection->ClearField(message, self->parent_field_descriptor);
+
+ Py_RETURN_NONE;
+}
+
+PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
+ MapContainer* self = GetMap(_self);
+
+ const Message* message = self->message;
+ const Reflection* reflection = message->GetReflection();
+ MapKey map_key;
+
+ if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
+ return NULL;
+ }
+
+ if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
+ map_key)) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+
+// Initializes the underlying Message object of "to" so it becomes a new parent
+// map container, and copies all the values from "from" to it. A child map
+// container can be released by passing it as both from and to (e.g. making it
+// the recipient of the new parent message and copying the values from itself).
+static int InitializeAndCopyToParentContainer(MapContainer* from,
+ MapContainer* to) {
+ // For now we require from == to, re-evaluate if we want to support deep copy
+ // as in repeated_scalar_container.cc.
+ GOOGLE_DCHECK(from == to);
+ Message* new_message = from->message->New();
+
+ if (MapReflectionFriend::Length(reinterpret_cast<PyObject*>(from)) > 0) {
+ // A somewhat roundabout way of copying just one field from old_message to
+ // new_message. This is the best we can do with what Reflection gives us.
+ Message* mutable_old = from->GetMutableMessage();
+ vector<const FieldDescriptor*> fields;
+ fields.push_back(from->parent_field_descriptor);
+
+ // Move the map field into the new message.
+ mutable_old->GetReflection()->SwapFields(mutable_old, new_message, fields);
+
+ // If/when we support from != to, this will be required also to copy the
+ // map field back into the existing message:
+ // mutable_old->MergeFrom(*new_message);
+ }
+
+ // If from == to this could delete old_message.
+ to->owner.reset(new_message);
+
+ to->parent = NULL;
+ to->parent_field_descriptor = from->parent_field_descriptor;
+ to->message = new_message;
+
+ // Invalidate iterators, since they point to the old copy of the field.
+ to->version++;
+
+ return 0;
+}
+
+int MapContainer::Release() {
+ return InitializeAndCopyToParentContainer(this, this);
+}
+
+
+// ScalarMap ///////////////////////////////////////////////////////////////////
+
+PyObject *NewScalarMapContainer(
+ CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
+ if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0));
+ if (obj.get() == NULL) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Could not allocate new container.");
+ }
+
+ MapContainer* self = GetMap(obj.get());
+
+ self->message = parent->message;
+ self->parent = parent;
+ self->parent_field_descriptor = parent_field_descriptor;
+ self->owner = parent->owner;
+ self->version = 0;
+
+ self->key_field_descriptor =
+ parent_field_descriptor->message_type()->FindFieldByName("key");
+ self->value_field_descriptor =
+ parent_field_descriptor->message_type()->FindFieldByName("value");
+
+ if (self->key_field_descriptor == NULL ||
+ self->value_field_descriptor == NULL) {
+ return PyErr_Format(PyExc_KeyError,
+ "Map entry descriptor did not have key/value fields");
+ }
+
+ return obj.release();
+}
+
+PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self,
+ PyObject* key) {
+ MapContainer* self = GetMap(_self);
+
+ Message* message = self->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+ MapKey map_key;
+ MapValueRef value;
+
+ if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
+ return NULL;
+ }
+
+ if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
+ map_key, &value)) {
+ self->version++;
+ }
+
+ return MapValueRefToPython(self->value_field_descriptor, &value);
+}
+
+int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
+ PyObject* v) {
+ MapContainer* self = GetMap(_self);
+
+ Message* message = self->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+ MapKey map_key;
+ MapValueRef value;
+
+ if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
+ return -1;
+ }
+
+ self->version++;
+
+ if (v) {
+ // Set item to v.
+ reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
+ map_key, &value);
+
+ return PythonToMapValueRef(v, self->value_field_descriptor,
+ reflection->SupportsUnknownEnumValues(), &value)
+ ? 0
+ : -1;
+ } else {
+ // Delete key from map.
+ if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
+ map_key)) {
+ return 0;
+ } else {
+ PyErr_Format(PyExc_KeyError, "Key not present in map");
+ return -1;
+ }
+ }
+}
+
+static PyObject* ScalarMapGet(PyObject* self, PyObject* args) {
+ PyObject* key;
+ PyObject* default_value = NULL;
+ if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
+ if (is_present.get() == NULL) {
+ return NULL;
+ }
+
+ if (PyObject_IsTrue(is_present.get())) {
+ return MapReflectionFriend::ScalarMapGetItem(self, key);
+ } else {
+ if (default_value != NULL) {
+ Py_INCREF(default_value);
+ return default_value;
+ } else {
+ Py_RETURN_NONE;
+ }
+ }
+}
+
+static void ScalarMapDealloc(PyObject* _self) {
+ MapContainer* self = GetMap(_self);
+ self->owner.reset();
+ Py_TYPE(_self)->tp_free(_self);
+}
+
+static PyMappingMethods ScalarMapMappingMethods = {
+ MapReflectionFriend::Length, // mp_length
+ MapReflectionFriend::ScalarMapGetItem, // mp_subscript
+ MapReflectionFriend::ScalarMapSetItem, // mp_ass_subscript
+};
+
+static PyMethodDef ScalarMapMethods[] = {
+ { "__contains__", MapReflectionFriend::Contains, METH_O,
+ "Tests whether a key is a member of the map." },
+ { "clear", (PyCFunction)Clear, METH_NOARGS,
+ "Removes all elements from the map." },
+ { "get", ScalarMapGet, METH_VARARGS,
+ "Gets the value for the given key if present, or otherwise a default" },
+ /*
+ { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
+ "Makes a deep copy of the class." },
+ { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
+ "Outputs picklable representation of the repeated field." },
+ */
+ {NULL, NULL},
+};
+
+PyTypeObject ScalarMapContainer_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".ScalarMapContainer", // tp_name
+ sizeof(MapContainer), // tp_basicsize
+ 0, // tp_itemsize
+ ScalarMapDealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ &ScalarMapMappingMethods, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A scalar map container", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ MapReflectionFriend::GetIterator, // tp_iter
+ 0, // tp_iternext
+ ScalarMapMethods, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+};
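
User-visible behavior of the scalar map type above, with a hypothetical message carrying a map<int32, string> field named values:

    from my_map_pb2 import MyMessage  # hypothetical generated module

    msg = MyMessage()
    msg.values[5] = 'five'                # ScalarMapSetItem
    print(msg.values[5])                  # ScalarMapGetItem -> 'five'
    print(msg.values[6])                  # missing key: inserted with default ''
    print(msg.values.get(7, 'fallback'))  # get() checks Contains first, no insert
    print(5 in msg.values)                # __contains__ -> True
    del msg.values[5]                     # deletion path of ScalarMapSetItem
    msg.values.clear()                    # Clear()
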
+
+
+// MessageMap //////////////////////////////////////////////////////////////////
+
+static MessageMapContainer* GetMessageMap(PyObject* obj) {
+ return reinterpret_cast<MessageMapContainer*>(obj);
+}
+
+static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
+ // Get or create the CMessage object corresponding to this message.
+ ScopedPyObjectPtr key(PyLong_FromVoidPtr(message));
+ PyObject* ret = PyDict_GetItem(self->message_dict, key.get());
+
+ if (ret == NULL) {
+ CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init,
+ message->GetDescriptor());
+ ret = reinterpret_cast<PyObject*>(cmsg);
+
+ if (cmsg == NULL) {
+ return NULL;
+ }
+ cmsg->owner = self->owner;
+ cmsg->message = message;
+ cmsg->parent = self->parent;
+
+ if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) {
+ Py_DECREF(ret);
+ return NULL;
+ }
+ } else {
+ Py_INCREF(ret);
+ }
+
+ return ret;
+}
+
+PyObject* NewMessageMapContainer(
+ CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
+ PyObject* concrete_class) {
+ if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
+ return NULL;
+ }
+
+ PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0);
+ if (obj == NULL) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Could not allocate new container.");
+ }
+
+ MessageMapContainer* self = GetMessageMap(obj);
+
+ self->message = parent->message;
+ self->parent = parent;
+ self->parent_field_descriptor = parent_field_descriptor;
+ self->owner = parent->owner;
+ self->version = 0;
+
+ self->key_field_descriptor =
+ parent_field_descriptor->message_type()->FindFieldByName("key");
+ self->value_field_descriptor =
+ parent_field_descriptor->message_type()->FindFieldByName("value");
+
+ self->message_dict = PyDict_New();
+ if (self->message_dict == NULL) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Could not allocate message dict.");
+ }
+
+ Py_INCREF(concrete_class);
+ self->subclass_init = concrete_class;
+
+ if (self->key_field_descriptor == NULL ||
+ self->value_field_descriptor == NULL) {
+ Py_DECREF(obj);
+ return PyErr_Format(PyExc_KeyError,
+ "Map entry descriptor did not have key/value fields");
+ }
+
+ return obj;
+}
+
+int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
+ PyObject* v) {
+ if (v) {
+ PyErr_Format(PyExc_ValueError,
+ "Direct assignment of submessage not allowed");
+ return -1;
+ }
+
+ // Now we know that this is a delete, not a set.
+
+ MessageMapContainer* self = GetMessageMap(_self);
+ Message* message = self->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+ MapKey map_key;
+ MapValueRef value;
+
+ self->version++;
+
+ if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
+ return -1;
+ }
+
+ // Delete key from map.
+ if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
+ map_key)) {
+ return 0;
+ } else {
+ PyErr_Format(PyExc_KeyError, "Key not present in map");
+ return -1;
+ }
+}
+
+PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self,
+ PyObject* key) {
+ MessageMapContainer* self = GetMessageMap(_self);
+
+ Message* message = self->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+ MapKey map_key;
+ MapValueRef value;
+
+ if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
+ return NULL;
+ }
+
+ if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
+ map_key, &value)) {
+ self->version++;
+ }
+
+ return GetCMessage(self, value.MutableMessageValue());
+}
+
+PyObject* MessageMapGet(PyObject* self, PyObject* args) {
+ PyObject* key;
+ PyObject* default_value = NULL;
+ if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
+ if (is_present.get() == NULL) {
+ return NULL;
+ }
+
+ if (PyObject_IsTrue(is_present.get())) {
+ return MapReflectionFriend::MessageMapGetItem(self, key);
+ } else {
+ if (default_value != NULL) {
+ Py_INCREF(default_value);
+ return default_value;
+ } else {
+ Py_RETURN_NONE;
+ }
+ }
+}
+
+static void MessageMapDealloc(PyObject* _self) {
+ MessageMapContainer* self = GetMessageMap(_self);
+ self->owner.reset();
+ Py_DECREF(self->message_dict);
+ Py_TYPE(_self)->tp_free(_self);
+}
+
+static PyMappingMethods MessageMapMappingMethods = {
+ MapReflectionFriend::Length, // mp_length
+ MapReflectionFriend::MessageMapGetItem, // mp_subscript
+ MapReflectionFriend::MessageMapSetItem, // mp_ass_subscript
+};
+
+static PyMethodDef MessageMapMethods[] = {
+ { "__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
+ "Tests whether the map contains this element."},
+ { "clear", (PyCFunction)Clear, METH_NOARGS,
+ "Removes all elements from the map."},
+ { "get", MessageMapGet, METH_VARARGS,
+ "Gets the value for the given key if present, or otherwise a default" },
+ { "get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
+ "Alias for getitem, useful to make explicit that the map is mutated." },
+ /*
+ { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
+ "Makes a deep copy of the class." },
+ { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
+ "Outputs picklable representation of the repeated field." },
+ */
+ {NULL, NULL},
+};
+
+PyTypeObject MessageMapContainer_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".MessageMapContainer", // tp_name
+ sizeof(MessageMapContainer), // tp_basicsize
+ 0, // tp_itemsize
+ MessageMapDealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ &MessageMapMappingMethods, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A map container for message", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ MapReflectionFriend::GetIterator, // tp_iter
+ 0, // tp_iternext
+ MessageMapMethods, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+};
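
The corresponding behavior for message maps, where MessageMapSetItem above forbids direct assignment and entries are mutated in place (names again hypothetical):

    from my_map_pb2 import MyMessage, SubMessage  # hypothetical: map<string, SubMessage> items

    msg = MyMessage()
    msg.items['a'].count = 3                # MessageMapGetItem inserts a live entry
    msg.items.get_or_create('b').count = 1  # explicit alias for the same lookup
    try:
      msg.items['c'] = SubMessage(count=2)  # rejected by MessageMapSetItem
    except ValueError:
      msg.items['c'].CopyFrom(SubMessage(count=2))
    del msg.items['a']                      # only deletion goes through SetItem
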
+
+// MapIterator /////////////////////////////////////////////////////////////////
+
+static MapIterator* GetIter(PyObject* obj) {
+ return reinterpret_cast<MapIterator*>(obj);
+}
+
+PyObject* MapReflectionFriend::GetIterator(PyObject *_self) {
+ MapContainer* self = GetMap(_self);
+
+ ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0));
+ if (obj == NULL) {
+ return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
+ }
+
+ MapIterator* iter = GetIter(obj.get());
+
+ Py_INCREF(self);
+ iter->container = self;
+ iter->version = self->version;
+ iter->owner = self->owner;
+
+ if (MapReflectionFriend::Length(_self) > 0) {
+ Message* message = self->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+
+ iter->iter.reset(new ::google::protobuf::MapIterator(
+ reflection->MapBegin(message, self->parent_field_descriptor)));
+ }
+
+ return obj.release();
+}
+
+PyObject* MapReflectionFriend::IterNext(PyObject* _self) {
+ MapIterator* self = GetIter(_self);
+
+ // This won't catch mutations to the map performed by MergeFrom(); no easy way
+ // to address that.
+ if (self->version != self->container->version) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Map modified during iteration.");
+ }
+
+ if (self->iter.get() == NULL) {
+ return NULL;
+ }
+
+ Message* message = self->container->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+
+ if (*self->iter ==
+ reflection->MapEnd(message, self->container->parent_field_descriptor)) {
+ return NULL;
+ }
+
+ PyObject* ret = MapKeyToPython(self->container->key_field_descriptor,
+ self->iter->GetKey());
+
+ ++(*self->iter);
+
+ return ret;
+}
+
+static void DeallocMapIterator(PyObject* _self) {
+ MapIterator* self = GetIter(_self);
+ self->iter.reset();
+ self->owner.reset();
+ Py_XDECREF(self->container);
+ Py_TYPE(_self)->tp_free(_self);
+}
+
+PyTypeObject MapIterator_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".MapIterator", // tp_name
+ sizeof(MapIterator), // tp_basicsize
+ 0, // tp_itemsize
+ DeallocMapIterator, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A scalar map iterator", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ PyObject_SelfIter, // tp_iter
+ MapReflectionFriend::IterNext, // tp_iternext
+ 0, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+};
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
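
The version counter gives the maps dict-like fail-fast iteration: any mutation invalidates live iterators. A sketch with the hypothetical scalar map from earlier:

    msg.values[1] = 'one'
    msg.values[2] = 'two'
    try:
      for key in msg.values:           # iterator captures container->version
        msg.values[key + 10] = 'new'   # any mutation bumps the version
    except RuntimeError:
      pass  # 'Map modified during iteration.'
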
diff --git a/python/google/protobuf/pyext/message_map_container.h b/python/google/protobuf/pyext/map_container.h
index 4f6cb26a..2de61187 100644
--- a/python/google/protobuf/pyext/message_map_container.h
+++ b/python/google/protobuf/pyext/map_container.h
@@ -28,8 +28,8 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_MAP_CONTAINER_H__
-#define GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_MAP_CONTAINER_H__
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__
#include <Python.h>
@@ -39,6 +39,7 @@
#endif
#include <google/protobuf/descriptor.h>
+#include <google/protobuf/message.h>
namespace google {
namespace protobuf {
@@ -55,18 +56,23 @@ namespace python {
struct CMessage;
-struct MessageMapContainer {
+// This struct is used directly for ScalarMap, and is the base class of
+// MessageMapContainer, which is used for MessageMap.
+struct MapContainer {
PyObject_HEAD;
// This is the top-level C++ Message object that owns the whole
- // proto tree. Every Python MessageMapContainer holds a
+ // proto tree. Every Python MapContainer holds a
// reference to it in order to keep it alive as long as there's a
// Python object that references any part of the tree.
shared_ptr<Message> owner;
// Pointer to the C++ Message that contains this container. The
- // MessageMapContainer does not own this pointer.
- Message* message;
+ // MapContainer does not own this pointer.
+ const Message* message;
+
+ // Used to get a mutable message when necessary.
+ Message* GetMutableMessage();
// Weak reference to a parent CMessage object (i.e. may be NULL.)
//
@@ -82,45 +88,46 @@ struct MessageMapContainer {
const FieldDescriptor* key_field_descriptor;
const FieldDescriptor* value_field_descriptor;
+ // We bump this whenever we perform a mutation, to invalidate existing
+ // iterators.
+ uint64 version;
+
+ // Releases the messages in the container to a new message.
+ //
+ // Returns 0 on success, -1 on failure.
+ int Release();
+
+ // Set the owner field of self and any children of self.
+ void SetOwner(const shared_ptr<Message>& new_owner) {
+ owner = new_owner;
+ }
+};
+
+struct MessageMapContainer : public MapContainer {
// A callable that is used to create new child messages.
PyObject* subclass_init;
// A dict mapping Message* -> CMessage.
PyObject* message_dict;
-
- // We bump this whenever we perform a mutation, to invalidate existing
- // iterators.
- uint64 version;
};
-#if PY_MAJOR_VERSION >= 3
- extern PyObject *MessageMapContainer_Type;
- extern PyType_Spec MessageMapContainer_Type_spec;
-#else
- extern PyTypeObject MessageMapContainer_Type;
-#endif
-extern PyTypeObject MessageMapIterator_Type;
+extern PyTypeObject ScalarMapContainer_Type;
+extern PyTypeObject MessageMapContainer_Type;
+extern PyTypeObject MapIterator_Type; // Both map types use the same iterator.
-namespace message_map_container {
-
-// Builds a MessageMapContainer object, from a parent message and a
+// Builds a ScalarMapContainer object from a parent message and a
// field descriptor.
-extern PyObject* NewContainer(CMessage* parent,
- const FieldDescriptor* parent_field_descriptor,
- PyObject* concrete_class);
-
-// Releases the messages in the container to a new message.
-//
-// Returns 0 on success, -1 on failure.
-int Release(MessageMapContainer* self);
+extern PyObject* NewScalarMapContainer(
+ CMessage* parent, const FieldDescriptor* parent_field_descriptor);
-// Set the owner field of self and any children of self.
-void SetOwner(MessageMapContainer* self,
- const shared_ptr<Message>& new_owner);
+// Builds a MessageMapContainer object from a parent message, a field
+// descriptor and a concrete message class.
+extern PyObject* NewMessageMapContainer(
+ CMessage* parent, const FieldDescriptor* parent_field_descriptor,
+ PyObject* concrete_class);
-} // namespace message_map_container
} // namespace python
} // namespace protobuf
} // namespace google
-#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_MAP_CONTAINER_H__
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index 72f51ec1..863cde01 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -33,6 +33,7 @@
#include <google/protobuf/pyext/message.h>
+#include <map>
#include <memory>
#ifndef _SHARED_PTR_H
#include <google/protobuf/stubs/shared_ptr.h>
@@ -61,8 +62,7 @@
#include <google/protobuf/pyext/extension_dict.h>
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/repeated_scalar_container.h>
-#include <google/protobuf/pyext/message_map_container.h>
-#include <google/protobuf/pyext/scalar_map_container.h>
+#include <google/protobuf/pyext/map_container.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/stubs/strutil.h>
@@ -96,6 +96,7 @@ static PyObject* k_extensions_by_number;
PyObject* EnumTypeWrapper_class;
static PyObject* PythonMessage_class;
static PyObject* kEmptyWeakref;
+static PyObject* WKT_classes = NULL;
// Defines the Metaclass of all Message classes.
// It allows us to cache some C++ pointers in the class object itself, they are
@@ -274,8 +275,32 @@ static PyObject* New(PyTypeObject* type,
// Build the arguments to the base metaclass.
// We change the __bases__ classes.
- ScopedPyObjectPtr new_args(Py_BuildValue(
- "s(OO)O", name, &CMessage_Type, PythonMessage_class, dict));
+ ScopedPyObjectPtr new_args;
+ const Descriptor* message_descriptor =
+ PyMessageDescriptor_AsDescriptor(py_descriptor);
+ if (message_descriptor == NULL) {
+ return NULL;
+ }
+
+ if (WKT_classes == NULL) {
+ ScopedPyObjectPtr well_known_types(PyImport_ImportModule(
+ "google.protobuf.internal.well_known_types"));
+ GOOGLE_DCHECK(well_known_types != NULL);
+
+ WKT_classes = PyObject_GetAttrString(well_known_types.get(), "WKTBASES");
+ GOOGLE_DCHECK(WKT_classes != NULL);
+ }
+
+ PyObject* well_known_class = PyDict_GetItemString(
+ WKT_classes, message_descriptor->full_name().c_str());
+ if (well_known_class == NULL) {
+ new_args.reset(Py_BuildValue("s(OO)O", name, &CMessage_Type,
+ PythonMessage_class, dict));
+ } else {
+ new_args.reset(Py_BuildValue("s(OOO)O", name, &CMessage_Type,
+ PythonMessage_class, well_known_class, dict));
+ }
+
if (new_args == NULL) {
return NULL;
}
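
The WKT_classes lookup splices an extra Python base class (from the new well_known_types module in this change) into generated classes for well-known types, so the helper methods appear on the C++ implementation too. For example, assuming the Timestamp helpers defined in that module:

    from google.protobuf import timestamp_pb2

    ts = timestamp_pb2.Timestamp()
    ts.GetCurrentTime()       # helper inherited from the well-known-types base
    print(ts.ToJsonString())  # RFC 3339 string for the current time
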
@@ -448,21 +473,9 @@ static int VisitCompositeField(const FieldDescriptor* descriptor,
if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
if (descriptor->is_map()) {
- const Descriptor* entry_type = descriptor->message_type();
- const FieldDescriptor* value_type =
- entry_type->FindFieldByName("value");
- if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- MessageMapContainer* container =
- reinterpret_cast<MessageMapContainer*>(child);
- if (visitor.VisitMessageMapContainer(container) == -1) {
- return -1;
- }
- } else {
- ScalarMapContainer* container =
- reinterpret_cast<ScalarMapContainer*>(child);
- if (visitor.VisitScalarMapContainer(container) == -1) {
- return -1;
- }
+ MapContainer* container = reinterpret_cast<MapContainer*>(child);
+ if (visitor.VisitMapContainer(container) == -1) {
+ return -1;
}
} else {
RepeatedCompositeContainer* container =
@@ -579,12 +592,14 @@ bool CheckAndGetInteger(
if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 ||
PyObject_RichCompareBool(max, arg, Py_GE) != 1) {
#endif
- PyObject *s = PyObject_Str(arg);
- if (s) {
- PyErr_Format(PyExc_ValueError,
- "Value out of range: %s",
- PyString_AsString(s));
- Py_DECREF(s);
+ if (!PyErr_Occurred()) {
+ PyObject *s = PyObject_Str(arg);
+ if (s) {
+ PyErr_Format(PyExc_ValueError,
+ "Value out of range: %s",
+ PyString_AsString(s));
+ Py_DECREF(s);
+ }
}
return false;
}
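[The guard above now formats the "Value out of range" error only when no other exception is already pending. A minimal sketch of the range check as seen from Python; MyMessage and its int32 field are hypothetical:]

    msg = MyMessage()
    msg.optional_int32 = 1 << 31   # exceeds int32 max -> ValueError: Value out of range
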
@@ -642,38 +657,51 @@ bool CheckAndGetBool(PyObject* arg, bool* value) {
return true;
}
-bool CheckAndSetString(
- PyObject* arg, Message* message,
- const FieldDescriptor* descriptor,
- const Reflection* reflection,
- bool append,
- int index) {
+// Checks whether the given object (which must be "bytes" or "unicode") contains
+// valid UTF-8.
+bool IsValidUTF8(PyObject* obj) {
+ if (PyBytes_Check(obj)) {
+ PyObject* unicode = PyUnicode_FromEncodedObject(obj, "utf-8", NULL);
+
+ // Clear the error indicator; we report our own error when desired.
+ PyErr_Clear();
+
+ if (unicode) {
+ Py_DECREF(unicode);
+ return true;
+ } else {
+ return false;
+ }
+ } else {
+ // Unicode object, known to be valid UTF-8.
+ return true;
+ }
+}
+
+bool AllowInvalidUTF8(const FieldDescriptor* field) { return false; }
+
+PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) {
GOOGLE_DCHECK(descriptor->type() == FieldDescriptor::TYPE_STRING ||
descriptor->type() == FieldDescriptor::TYPE_BYTES);
if (descriptor->type() == FieldDescriptor::TYPE_STRING) {
if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) {
FormatTypeError(arg, "bytes, unicode");
- return false;
+ return NULL;
}
- if (PyBytes_Check(arg)) {
- PyObject* unicode = PyUnicode_FromEncodedObject(arg, "utf-8", NULL);
- if (unicode == NULL) {
- PyObject* repr = PyObject_Repr(arg);
- PyErr_Format(PyExc_ValueError,
- "%s has type str, but isn't valid UTF-8 "
- "encoding. Non-UTF-8 strings must be converted to "
- "unicode objects before being added.",
- PyString_AsString(repr));
- Py_DECREF(repr);
- return false;
- } else {
- Py_DECREF(unicode);
- }
+ if (!IsValidUTF8(arg) && !AllowInvalidUTF8(descriptor)) {
+ PyObject* repr = PyObject_Repr(arg);
+ PyErr_Format(PyExc_ValueError,
+ "%s has type str, but isn't valid UTF-8 "
+ "encoding. Non-UTF-8 strings must be converted to "
+ "unicode objects before being added.",
+ PyString_AsString(repr));
+ Py_DECREF(repr);
+ return NULL;
}
} else if (!PyBytes_Check(arg)) {
FormatTypeError(arg, "bytes");
- return false;
+ return NULL;
}
PyObject* encoded_string = NULL;
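[A short sketch of this string validation from the Python side, assuming a hypothetical generated message with a string field named "name":]

    msg = MyMessage()
    msg.name = u'valid text'   # unicode values are always accepted
    msg.name = b'\xc3\x28'     # bytes that are not valid UTF-8 -> ValueError
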
@@ -691,14 +719,24 @@ bool CheckAndSetString(
Py_INCREF(encoded_string);
}
- if (encoded_string == NULL) {
+ return encoded_string;
+}
+
+bool CheckAndSetString(
+ PyObject* arg, Message* message,
+ const FieldDescriptor* descriptor,
+ const Reflection* reflection,
+ bool append,
+ int index) {
+ ScopedPyObjectPtr encoded_string(CheckString(arg, descriptor));
+
+ if (encoded_string.get() == NULL) {
return false;
}
char* value;
Py_ssize_t value_len;
- if (PyBytes_AsStringAndSize(encoded_string, &value, &value_len) < 0) {
- Py_DECREF(encoded_string);
+ if (PyBytes_AsStringAndSize(encoded_string.get(), &value, &value_len) < 0) {
return false;
}
@@ -710,7 +748,6 @@ bool CheckAndSetString(
} else {
reflection->SetRepeatedString(message, descriptor, index, value_string);
}
- Py_DECREF(encoded_string);
return true;
}
@@ -823,12 +860,7 @@ struct FixupMessageReference : public ChildVisitor {
return 0;
}
- int VisitScalarMapContainer(ScalarMapContainer* container) {
- container->message = message_;
- return 0;
- }
-
- int VisitMessageMapContainer(MessageMapContainer* container) {
+ int VisitMapContainer(MapContainer* container) {
container->message = message_;
return 0;
}
@@ -870,9 +902,8 @@ int AssureWritable(CMessage* self) {
// When a CMessage is made writable its Message pointer is updated
// to point to a new mutable Message. When that happens we need to
// update any references to the old, read-only CMessage. There are
- // five places such references occur: RepeatedScalarContainer,
- // RepeatedCompositeContainer, ScalarMapContainer, MessageMapContainer,
- // and ExtensionDict.
+ // four places such references occur: RepeatedScalarContainer,
+ // RepeatedCompositeContainer, MapContainer, and ExtensionDict.
if (self->extensions != NULL)
self->extensions->message = self->message;
if (ForEachCompositeField(self, FixupMessageReference(self->message)) == -1)
@@ -1054,7 +1085,8 @@ int InitAttributes(CMessage* self, PyObject* kwargs) {
}
const FieldDescriptor* descriptor = GetFieldDescriptor(self, name);
if (descriptor == NULL) {
- PyErr_Format(PyExc_ValueError, "Protocol message has no \"%s\" field.",
+ PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.",
+ self->message->GetDescriptor()->name().c_str(),
PyString_AsString(name));
return -1;
}
@@ -1203,18 +1235,6 @@ CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) {
self->composite_fields = NULL;
- // If there are extension_ranges, the message is "extendable". Allocate a
- // dictionary to store the extension fields.
- if (descriptor->extension_range_count() > 0) {
- // TODO(amauryfa): Delay the construction of this dict until extensions are
- // really used on the object.
- ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self);
- if (extension_dict == NULL) {
- return NULL;
- }
- self->extensions = extension_dict;
- }
-
return self;
}
@@ -1285,12 +1305,7 @@ struct ClearWeakReferences : public ChildVisitor {
return 0;
}
- int VisitScalarMapContainer(ScalarMapContainer* container) {
- container->parent = NULL;
- return 0;
- }
-
- int VisitMessageMapContainer(MessageMapContainer* container) {
+ int VisitMapContainer(MapContainer* container) {
container->parent = NULL;
return 0;
}
@@ -1305,6 +1320,9 @@ struct ClearWeakReferences : public ChildVisitor {
static void Dealloc(CMessage* self) {
// Null out all weak references from children to this message.
GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences()));
+ if (self->extensions) {
+ self->extensions->parent = NULL;
+ }
Py_CLEAR(self->extensions);
Py_CLEAR(self->composite_fields);
@@ -1466,20 +1484,27 @@ PyObject* HasField(CMessage* self, PyObject* arg) {
Py_RETURN_FALSE;
}
-PyObject* ClearExtension(CMessage* self, PyObject* arg) {
+PyObject* ClearExtension(CMessage* self, PyObject* extension) {
if (self->extensions != NULL) {
- return extension_dict::ClearExtension(self->extensions, arg);
+ return extension_dict::ClearExtension(self->extensions, extension);
+ } else {
+ const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
+ if (descriptor == NULL) {
+ return NULL;
+ }
+ if (ScopedPyObjectPtr(ClearFieldByDescriptor(self, descriptor)) == NULL) {
+ return NULL;
+ }
}
- PyErr_SetString(PyExc_TypeError, "Message is not extendable");
- return NULL;
+ Py_RETURN_NONE;
}
-PyObject* HasExtension(CMessage* self, PyObject* arg) {
- if (self->extensions != NULL) {
- return extension_dict::HasExtension(self->extensions, arg);
+PyObject* HasExtension(CMessage* self, PyObject* extension) {
+ const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
+ if (descriptor == NULL) {
+ return NULL;
}
- PyErr_SetString(PyExc_TypeError, "Message is not extendable");
- return NULL;
+ return HasFieldByDescriptor(self, descriptor);
}
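[HasExtension and ClearExtension now resolve the extension descriptor directly instead of requiring an allocated Extensions dict, so they work even before the dict exists. A hedged sketch; my_pb2 and my_extension are hypothetical:]

    from my_pb2 import MyMessage, my_extension

    msg = MyMessage()
    msg.HasExtension(my_extension)    # False; no longer raises TypeError
    msg.ClearExtension(my_extension)  # a no-op when the field is unset
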
// ---------------------------------------------------------------------
@@ -1529,13 +1554,8 @@ struct SetOwnerVisitor : public ChildVisitor {
return 0;
}
- int VisitScalarMapContainer(ScalarMapContainer* container) {
- scalar_map_container::SetOwner(container, new_owner_);
- return 0;
- }
-
- int VisitMessageMapContainer(MessageMapContainer* container) {
- message_map_container::SetOwner(container, new_owner_);
+ int VisitMapContainer(MapContainer* container) {
+ container->SetOwner(new_owner_);
return 0;
}
@@ -1608,14 +1628,8 @@ struct ReleaseChild : public ChildVisitor {
reinterpret_cast<RepeatedScalarContainer*>(container));
}
- int VisitScalarMapContainer(ScalarMapContainer* container) {
- return scalar_map_container::Release(
- reinterpret_cast<ScalarMapContainer*>(container));
- }
-
- int VisitMessageMapContainer(MessageMapContainer* container) {
- return message_map_container::Release(
- reinterpret_cast<MessageMapContainer*>(container));
+ int VisitMapContainer(MapContainer* container) {
+ return reinterpret_cast<MapContainer*>(container)->Release();
}
int VisitCMessage(CMessage* cmessage,
@@ -1707,17 +1721,7 @@ PyObject* Clear(CMessage* self) {
AssureWritable(self);
if (ForEachCompositeField(self, ReleaseChild(self)) == -1)
return NULL;
-
- // The old ExtensionDict still aliases this CMessage, but all its
- // fields have been released.
- if (self->extensions != NULL) {
- Py_CLEAR(self->extensions);
- ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self);
- if (extension_dict == NULL) {
- return NULL;
- }
- self->extensions = extension_dict;
- }
+ Py_CLEAR(self->extensions);
if (self->composite_fields) {
PyDict_Clear(self->composite_fields);
}
@@ -1997,7 +2001,6 @@ static PyObject* RegisterExtension(PyObject* cls,
if (descriptor->is_extension() &&
descriptor->containing_type()->options().message_set_wire_format() &&
descriptor->type() == FieldDescriptor::TYPE_MESSAGE &&
- descriptor->message_type() == descriptor->extension_scope() &&
descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) {
ScopedPyObjectPtr message_name(PyString_FromStringAndSize(
descriptor->message_type()->full_name().c_str(),
@@ -2042,6 +2045,8 @@ static PyObject* WhichOneof(CMessage* self, PyObject* arg) {
}
}
+static PyObject* GetExtensionDict(CMessage* self, void *closure);
+
static PyObject* ListFields(CMessage* self) {
vector<const FieldDescriptor*> fields;
self->message->GetReflection()->ListFields(*self->message, &fields);
@@ -2079,12 +2084,13 @@ static PyObject* ListFields(CMessage* self) {
PyErr_Clear();
continue;
}
- PyObject* extensions = reinterpret_cast<PyObject*>(self->extensions);
+ ScopedPyObjectPtr extensions(GetExtensionDict(self, NULL));
if (extensions == NULL) {
return NULL;
}
// 'extension' reference later stolen by PyTuple_SET_ITEM.
- PyObject* extension = PyObject_GetItem(extensions, extension_field.get());
+ PyObject* extension = PyObject_GetItem(
+ extensions.get(), extension_field.get());
if (extension == NULL) {
return NULL;
}
@@ -2493,9 +2499,31 @@ PyObject* _CheckCalledFromGeneratedFile(PyObject* unused,
Py_RETURN_NONE;
}
-static PyMemberDef Members[] = {
- {"Extensions", T_OBJECT_EX, offsetof(CMessage, extensions), 0,
- "Extension dict"},
+static PyObject* GetExtensionDict(CMessage* self, void *closure) {
+ if (self->extensions) {
+ Py_INCREF(self->extensions);
+ return reinterpret_cast<PyObject*>(self->extensions);
+ }
+
+ // If there are extension_ranges, the message is "extendable". Allocate a
+ // dictionary to store the extension fields.
+ const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self));
+ if (descriptor->extension_range_count() > 0) {
+ ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self);
+ if (extension_dict == NULL) {
+ return NULL;
+ }
+ self->extensions = extension_dict;
+ Py_INCREF(self->extensions);
+ return reinterpret_cast<PyObject*>(self->extensions);
+ }
+
+ PyErr_SetNone(PyExc_AttributeError);
+ return NULL;
+}
+
+static PyGetSetDef Getters[] = {
+ {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"},
{NULL}
};
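[Extensions is now a getter that allocates the extension dict lazily on first access and raises AttributeError for messages without extension ranges. A sketch under those assumptions; both message classes are hypothetical:]

    ext_dict = extendable_msg.Extensions   # dict created lazily on first access
    plain_msg.Extensions                   # AttributeError: no extension ranges
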
@@ -2592,10 +2620,10 @@ PyObject* GetAttr(CMessage* self, PyObject* name) {
if (value_class == NULL) {
return NULL;
}
- py_container = message_map_container::NewContainer(self, field_descriptor,
- value_class);
+ py_container =
+ NewMessageMapContainer(self, field_descriptor, value_class);
} else {
- py_container = scalar_map_container::NewContainer(self, field_descriptor);
+ py_container = NewScalarMapContainer(self, field_descriptor);
}
if (py_container == NULL) {
return NULL;
@@ -2672,7 +2700,10 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) {
}
}
- PyErr_Format(PyExc_AttributeError, "Assignment not allowed");
+ PyErr_Format(PyExc_AttributeError,
+ "Assignment not allowed "
+ "(no field \"%s\"in protocol message object).",
+ PyString_AsString(name));
return -1;
}
@@ -2707,8 +2738,8 @@ PyTypeObject CMessage_Type = {
0, // tp_iter
0, // tp_iternext
cmessage::Methods, // tp_methods
- cmessage::Members, // tp_members
- 0, // tp_getset
+ 0, // tp_members
+ cmessage::Getters, // tp_getset
0, // tp_base
0, // tp_dict
0, // tp_descr_get
@@ -2910,12 +2941,12 @@ bool InitProto2MessageModule(PyObject *m) {
reinterpret_cast<PyObject*>(&ScalarMapContainer_Type));
#endif
- if (PyType_Ready(&ScalarMapIterator_Type) < 0) {
+ if (PyType_Ready(&MapIterator_Type) < 0) {
return false;
}
- PyModule_AddObject(m, "ScalarMapIterator",
- reinterpret_cast<PyObject*>(&ScalarMapIterator_Type));
+ PyModule_AddObject(m, "MapIterator",
+ reinterpret_cast<PyObject*>(&MapIterator_Type));
#if PY_MAJOR_VERSION >= 3
@@ -2934,13 +2965,6 @@ bool InitProto2MessageModule(PyObject *m) {
PyModule_AddObject(m, "MessageMapContainer",
reinterpret_cast<PyObject*>(&MessageMapContainer_Type));
#endif
-
- if (PyType_Ready(&MessageMapIterator_Type) < 0) {
- return false;
- }
-
- PyModule_AddObject(m, "MessageMapIterator",
- reinterpret_cast<PyObject*>(&MessageMapIterator_Type));
}
if (PyType_Ready(&ExtensionDict_Type) < 0) {
@@ -2957,6 +2981,9 @@ bool InitProto2MessageModule(PyObject *m) {
PyModule_AddObject(m, "default_pool",
reinterpret_cast<PyObject*>(GetDefaultDescriptorPool()));
+ PyModule_AddObject(m, "DescriptorPool", reinterpret_cast<PyObject*>(
+ &PyDescriptorPool_Type));
+
// This implementation provides full Descriptor types, we advertise it so that
// descriptor.py can use them in replacement of the Python classes.
PyModule_AddIntConstant(m, "_USE_C_DESCRIPTORS", 1);
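[The module now also exports the DescriptorPool type. A minimal sketch of constructing a pool through the public wrapper, which is presumably backed by _message.DescriptorPool when C descriptors are in use:]

    from google.protobuf import descriptor_pool

    pool = descriptor_pool.DescriptorPool()
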
diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h
index 94de4551..cc0012e9 100644
--- a/python/google/protobuf/pyext/message.h
+++ b/python/google/protobuf/pyext/message.h
@@ -307,6 +307,7 @@ bool CheckAndGetInteger(
bool CheckAndGetDouble(PyObject* arg, double* value);
bool CheckAndGetFloat(PyObject* arg, float* value);
bool CheckAndGetBool(PyObject* arg, bool* value);
+PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor);
bool CheckAndSetString(
PyObject* arg, Message* message,
const FieldDescriptor* descriptor,
diff --git a/python/google/protobuf/pyext/message_map_container.cc b/python/google/protobuf/pyext/message_map_container.cc
deleted file mode 100644
index 8902fa00..00000000
--- a/python/google/protobuf/pyext/message_map_container.cc
+++ /dev/null
@@ -1,569 +0,0 @@
-// Protocol Buffers - Google's data interchange format
-// Copyright 2008 Google Inc. All rights reserved.
-// https://developers.google.com/protocol-buffers/
-//
-// Redistribution and use in source and binary forms, with or without
-// modification, are permitted provided that the following conditions are
-// met:
-//
-// * Redistributions of source code must retain the above copyright
-// notice, this list of conditions and the following disclaimer.
-// * Redistributions in binary form must reproduce the above
-// copyright notice, this list of conditions and the following disclaimer
-// in the documentation and/or other materials provided with the
-// distribution.
-// * Neither the name of Google Inc. nor the names of its
-// contributors may be used to endorse or promote products derived from
-// this software without specific prior written permission.
-//
-// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-// Author: haberman@google.com (Josh Haberman)
-
-#include <google/protobuf/pyext/message_map_container.h>
-
-#include <google/protobuf/stubs/logging.h>
-#include <google/protobuf/stubs/common.h>
-#include <google/protobuf/message.h>
-#include <google/protobuf/pyext/message.h>
-#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
-
-namespace google {
-namespace protobuf {
-namespace python {
-
-struct MessageMapIterator {
- PyObject_HEAD;
-
- // This dict contains the full contents of what we want to iterate over.
- // There's no way to avoid building this, because the list representation
- // (which is canonical) can contain duplicate keys. So at the very least we
- // need a set that lets us skip duplicate keys. And at the point that we're
- // doing that, we might as well just build the actual dict we're iterating
- // over and use dict's built-in iterator.
- PyObject* dict;
-
- // An iterator on dict.
- PyObject* iter;
-
- // A pointer back to the container, so we can notice changes to the version.
- MessageMapContainer* container;
-
- // The version of the map when we took the iterator to it.
- //
- // We store this so that if the map is modified during iteration we can throw
- // an error.
- uint64 version;
-};
-
-static MessageMapIterator* GetIter(PyObject* obj) {
- return reinterpret_cast<MessageMapIterator*>(obj);
-}
-
-namespace message_map_container {
-
-static MessageMapContainer* GetMap(PyObject* obj) {
- return reinterpret_cast<MessageMapContainer*>(obj);
-}
-
-// The private constructor of MessageMapContainer objects.
-PyObject* NewContainer(CMessage* parent,
- const google::protobuf::FieldDescriptor* parent_field_descriptor,
- PyObject* concrete_class) {
- if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
- return NULL;
- }
-
-#if PY_MAJOR_VERSION >= 3
- PyObject* obj = PyType_GenericAlloc(
- reinterpret_cast<PyTypeObject *>(MessageMapContainer_Type), 0);
-#else
- PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0);
-#endif
- if (obj == NULL) {
- return PyErr_Format(PyExc_RuntimeError,
- "Could not allocate new container.");
- }
-
- MessageMapContainer* self = GetMap(obj);
-
- self->message = parent->message;
- self->parent = parent;
- self->parent_field_descriptor = parent_field_descriptor;
- self->owner = parent->owner;
- self->version = 0;
-
- self->key_field_descriptor =
- parent_field_descriptor->message_type()->FindFieldByName("key");
- self->value_field_descriptor =
- parent_field_descriptor->message_type()->FindFieldByName("value");
-
- self->message_dict = PyDict_New();
- if (self->message_dict == NULL) {
- return PyErr_Format(PyExc_RuntimeError,
- "Could not allocate message dict.");
- }
-
- Py_INCREF(concrete_class);
- self->subclass_init = concrete_class;
-
- if (self->key_field_descriptor == NULL ||
- self->value_field_descriptor == NULL) {
- Py_DECREF(obj);
- return PyErr_Format(PyExc_KeyError,
- "Map entry descriptor did not have key/value fields");
- }
-
- return obj;
-}
-
-// Initializes the underlying Message object of "to" so it becomes a new parent
-// repeated scalar, and copies all the values from "from" to it. A child scalar
-// container can be released by passing it as both from and to (e.g. making it
-// the recipient of the new parent message and copying the values from itself).
-static int InitializeAndCopyToParentContainer(
- MessageMapContainer* from,
- MessageMapContainer* to) {
- // For now we require from == to, re-evaluate if we want to support deep copy
- // as in repeated_composite_container.cc.
- GOOGLE_DCHECK(from == to);
- Message* old_message = from->message;
- Message* new_message = old_message->New();
- to->parent = NULL;
- to->parent_field_descriptor = from->parent_field_descriptor;
- to->message = new_message;
- to->owner.reset(new_message);
-
- vector<const FieldDescriptor*> fields;
- fields.push_back(from->parent_field_descriptor);
- old_message->GetReflection()->SwapFields(old_message, new_message, fields);
- return 0;
-}
-
-static PyObject* GetCMessage(MessageMapContainer* self, Message* entry) {
- // Get or create the CMessage object corresponding to this message.
- Message* message = entry->GetReflection()->MutableMessage(
- entry, self->value_field_descriptor);
- ScopedPyObjectPtr key(PyLong_FromVoidPtr(message));
- PyObject* ret = PyDict_GetItem(self->message_dict, key.get());
-
- if (ret == NULL) {
- CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init,
- message->GetDescriptor());
- ret = reinterpret_cast<PyObject*>(cmsg);
-
- if (cmsg == NULL) {
- return NULL;
- }
- cmsg->owner = self->owner;
- cmsg->message = message;
- cmsg->parent = self->parent;
-
- if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) {
- Py_DECREF(ret);
- return NULL;
- }
- } else {
- Py_INCREF(ret);
- }
-
- return ret;
-}
-
-int Release(MessageMapContainer* self) {
- InitializeAndCopyToParentContainer(self, self);
- return 0;
-}
-
-void SetOwner(MessageMapContainer* self,
- const shared_ptr<Message>& new_owner) {
- self->owner = new_owner;
-}
-
-Py_ssize_t Length(PyObject* _self) {
- MessageMapContainer* self = GetMap(_self);
- google::protobuf::Message* message = self->message;
- return message->GetReflection()->FieldSize(*message,
- self->parent_field_descriptor);
-}
-
-int MapKeyMatches(MessageMapContainer* self, const Message* entry,
- PyObject* key) {
- // TODO(haberman): do we need more strict type checking?
- ScopedPyObjectPtr entry_key(
- cmessage::InternalGetScalar(entry, self->key_field_descriptor));
- int ret = PyObject_RichCompareBool(key, entry_key.get(), Py_EQ);
- return ret;
-}
-
-int SetItem(PyObject *_self, PyObject *key, PyObject *v) {
- if (v) {
- PyErr_Format(PyExc_ValueError,
- "Direct assignment of submessage not allowed");
- return -1;
- }
-
- // Now we know that this is a delete, not a set.
-
- MessageMapContainer* self = GetMap(_self);
- cmessage::AssureWritable(self->parent);
-
- Message* message = self->message;
- const Reflection* reflection = message->GetReflection();
- size_t size =
- reflection->FieldSize(*message, self->parent_field_descriptor);
-
- // Right now the Reflection API doesn't support map lookup, so we implement it
- // via linear search. We need to search from the end because the underlying
- // representation can have duplicates if a user calls MergeFrom(); the last
- // one needs to win.
- //
- // TODO(haberman): add lookup API to Reflection API.
- bool found = false;
- for (int i = size - 1; i >= 0; i--) {
- Message* entry = reflection->MutableRepeatedMessage(
- message, self->parent_field_descriptor, i);
- int matches = MapKeyMatches(self, entry, key);
- if (matches < 0) return -1;
- if (matches) {
- found = true;
- if (i != (int)size - 1) {
- reflection->SwapElements(message, self->parent_field_descriptor, i,
- size - 1);
- }
- reflection->RemoveLast(message, self->parent_field_descriptor);
-
- // Can't exit now, the repeated field representation of maps allows
- // duplicate keys, and we have to be sure to remove all of them.
- }
- }
-
- if (!found) {
- PyErr_Format(PyExc_KeyError, "Key not present in map");
- return -1;
- }
-
- self->version++;
-
- return 0;
-}
-
-PyObject* GetIterator(PyObject *_self) {
- MessageMapContainer* self = GetMap(_self);
-
- ScopedPyObjectPtr obj(PyType_GenericAlloc(&MessageMapIterator_Type, 0));
- if (obj == NULL) {
- return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
- }
-
- MessageMapIterator* iter = GetIter(obj.get());
-
- Py_INCREF(self);
- iter->container = self;
- iter->version = self->version;
- iter->dict = PyDict_New();
- if (iter->dict == NULL) {
- return PyErr_Format(PyExc_RuntimeError,
- "Could not allocate dict for iterator.");
- }
-
- // Build the entire map into a dict right now. Start from the beginning so
- // that later entries win in the case of duplicates.
- Message* message = self->message;
- const Reflection* reflection = message->GetReflection();
-
- // Right now the Reflection API doesn't support map lookup, so we implement it
- // via linear search. We need to search from the end because the underlying
- // representation can have duplicates if a user calls MergeFrom(); the last
- // one needs to win.
- //
- // TODO(haberman): add lookup API to Reflection API.
- size_t size =
- reflection->FieldSize(*message, self->parent_field_descriptor);
- for (int i = size - 1; i >= 0; i--) {
- Message* entry = reflection->MutableRepeatedMessage(
- message, self->parent_field_descriptor, i);
- ScopedPyObjectPtr key(
- cmessage::InternalGetScalar(entry, self->key_field_descriptor));
- if (PyDict_SetItem(iter->dict, key.get(), GetCMessage(self, entry)) < 0) {
- return PyErr_Format(PyExc_RuntimeError,
- "SetItem failed in iterator construction.");
- }
- }
-
- iter->iter = PyObject_GetIter(iter->dict);
-
- return obj.release();
-}
-
-PyObject* GetItem(PyObject* _self, PyObject* key) {
- MessageMapContainer* self = GetMap(_self);
- cmessage::AssureWritable(self->parent);
- Message* message = self->message;
- const Reflection* reflection = message->GetReflection();
-
- // Right now the Reflection API doesn't support map lookup, so we implement it
- // via linear search. We need to search from the end because the underlying
- // representation can have duplicates if a user calls MergeFrom(); the last
- // one needs to win.
- //
- // TODO(haberman): add lookup API to Reflection API.
- size_t size =
- reflection->FieldSize(*message, self->parent_field_descriptor);
- for (int i = size - 1; i >= 0; i--) {
- Message* entry = reflection->MutableRepeatedMessage(
- message, self->parent_field_descriptor, i);
- int matches = MapKeyMatches(self, entry, key);
- if (matches < 0) return NULL;
- if (matches) {
- return GetCMessage(self, entry);
- }
- }
-
- // Key is not already present; insert a new entry.
- Message* entry =
- reflection->AddMessage(message, self->parent_field_descriptor);
-
- self->version++;
-
- if (cmessage::InternalSetNonOneofScalar(entry, self->key_field_descriptor,
- key) < 0) {
- reflection->RemoveLast(message, self->parent_field_descriptor);
- return NULL;
- }
-
- return GetCMessage(self, entry);
-}
-
-PyObject* Contains(PyObject* _self, PyObject* key) {
- MessageMapContainer* self = GetMap(_self);
- Message* message = self->message;
- const Reflection* reflection = message->GetReflection();
-
- // Right now the Reflection API doesn't support map lookup, so we implement it
- // via linear search.
- //
- // TODO(haberman): add lookup API to Reflection API.
- int size =
- reflection->FieldSize(*message, self->parent_field_descriptor);
- for (int i = 0; i < size; i++) {
- Message* entry = reflection->MutableRepeatedMessage(
- message, self->parent_field_descriptor, i);
- int matches = MapKeyMatches(self, entry, key);
- if (matches < 0) return NULL;
- if (matches) {
- Py_RETURN_TRUE;
- }
- }
-
- Py_RETURN_FALSE;
-}
-
-PyObject* Clear(PyObject* _self) {
- MessageMapContainer* self = GetMap(_self);
- cmessage::AssureWritable(self->parent);
- Message* message = self->message;
- const Reflection* reflection = message->GetReflection();
-
- self->version++;
- reflection->ClearField(message, self->parent_field_descriptor);
-
- Py_RETURN_NONE;
-}
-
-PyObject* Get(PyObject* self, PyObject* args) {
- PyObject* key;
- PyObject* default_value = NULL;
- if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
- return NULL;
- }
-
- ScopedPyObjectPtr is_present(Contains(self, key));
- if (is_present.get() == NULL) {
- return NULL;
- }
-
- if (PyObject_IsTrue(is_present.get())) {
- return GetItem(self, key);
- } else {
- if (default_value != NULL) {
- Py_INCREF(default_value);
- return default_value;
- } else {
- Py_RETURN_NONE;
- }
- }
-}
-
-static void Dealloc(PyObject* _self) {
- MessageMapContainer* self = GetMap(_self);
- self->owner.reset();
- Py_DECREF(self->message_dict);
- Py_TYPE(_self)->tp_free(_self);
-}
-
-static PyMethodDef Methods[] = {
- { "__contains__", (PyCFunction)Contains, METH_O,
- "Tests whether the map contains this element."},
- { "clear", (PyCFunction)Clear, METH_NOARGS,
- "Removes all elements from the map."},
- { "get", Get, METH_VARARGS,
- "Gets the value for the given key if present, or otherwise a default" },
- { "get_or_create", GetItem, METH_O,
- "Alias for getitem, useful to make explicit that the map is mutated." },
- /*
- { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
- "Makes a deep copy of the class." },
- { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
- "Outputs picklable representation of the repeated field." },
- */
- {NULL, NULL},
-};
-
-} // namespace message_map_container
-
-namespace message_map_iterator {
-
-static void Dealloc(PyObject* _self) {
- MessageMapIterator* self = GetIter(_self);
- Py_DECREF(self->dict);
- Py_DECREF(self->iter);
- Py_DECREF(self->container);
- Py_TYPE(_self)->tp_free(_self);
-}
-
-PyObject* IterNext(PyObject* _self) {
- MessageMapIterator* self = GetIter(_self);
-
- // This won't catch mutations to the map performed by MergeFrom(); no easy way
- // to address that.
- if (self->version != self->container->version) {
- return PyErr_Format(PyExc_RuntimeError,
- "Map modified during iteration.");
- }
-
- return PyIter_Next(self->iter);
-}
-
-} // namespace message_map_iterator
-
-#if PY_MAJOR_VERSION >= 3
- static PyType_Slot MessageMapContainer_Type_slots[] = {
- {Py_tp_dealloc, (void *)message_map_container::Dealloc},
- {Py_mp_length, (void *)message_map_container::Length},
- {Py_mp_subscript, (void *)message_map_container::GetItem},
- {Py_mp_ass_subscript, (void *)message_map_container::SetItem},
- {Py_tp_methods, (void *)message_map_container::Methods},
- {Py_tp_iter, (void *)message_map_container::GetIterator},
- {0, 0}
- };
-
- PyType_Spec MessageMapContainer_Type_spec = {
- FULL_MODULE_NAME ".MessageMapContainer",
- sizeof(MessageMapContainer),
- 0,
- Py_TPFLAGS_DEFAULT,
- MessageMapContainer_Type_slots
- };
-
- PyObject *MessageMapContainer_Type;
-
-#else
- static PyMappingMethods MpMethods = {
- message_map_container::Length, // mp_length
- message_map_container::GetItem, // mp_subscript
- message_map_container::SetItem, // mp_ass_subscript
- };
-
- PyTypeObject MessageMapContainer_Type = {
- PyVarObject_HEAD_INIT(&PyType_Type, 0)
- FULL_MODULE_NAME ".MessageMapContainer", // tp_name
- sizeof(MessageMapContainer), // tp_basicsize
- 0, // tp_itemsize
- message_map_container::Dealloc, // tp_dealloc
- 0, // tp_print
- 0, // tp_getattr
- 0, // tp_setattr
- 0, // tp_compare
- 0, // tp_repr
- 0, // tp_as_number
- 0, // tp_as_sequence
- &MpMethods, // tp_as_mapping
- 0, // tp_hash
- 0, // tp_call
- 0, // tp_str
- 0, // tp_getattro
- 0, // tp_setattro
- 0, // tp_as_buffer
- Py_TPFLAGS_DEFAULT, // tp_flags
- "A map container for message", // tp_doc
- 0, // tp_traverse
- 0, // tp_clear
- 0, // tp_richcompare
- 0, // tp_weaklistoffset
- message_map_container::GetIterator, // tp_iter
- 0, // tp_iternext
- message_map_container::Methods, // tp_methods
- 0, // tp_members
- 0, // tp_getset
- 0, // tp_base
- 0, // tp_dict
- 0, // tp_descr_get
- 0, // tp_descr_set
- 0, // tp_dictoffset
- 0, // tp_init
- };
-#endif
-
-PyTypeObject MessageMapIterator_Type = {
- PyVarObject_HEAD_INIT(&PyType_Type, 0)
- FULL_MODULE_NAME ".MessageMapIterator", // tp_name
- sizeof(MessageMapIterator), // tp_basicsize
- 0, // tp_itemsize
- message_map_iterator::Dealloc, // tp_dealloc
- 0, // tp_print
- 0, // tp_getattr
- 0, // tp_setattr
- 0, // tp_compare
- 0, // tp_repr
- 0, // tp_as_number
- 0, // tp_as_sequence
- 0, // tp_as_mapping
- 0, // tp_hash
- 0, // tp_call
- 0, // tp_str
- 0, // tp_getattro
- 0, // tp_setattro
- 0, // tp_as_buffer
- Py_TPFLAGS_DEFAULT, // tp_flags
- "A scalar map iterator", // tp_doc
- 0, // tp_traverse
- 0, // tp_clear
- 0, // tp_richcompare
- 0, // tp_weaklistoffset
- PyObject_SelfIter, // tp_iter
- message_map_iterator::IterNext, // tp_iternext
- 0, // tp_methods
- 0, // tp_members
- 0, // tp_getset
- 0, // tp_base
- 0, // tp_dict
- 0, // tp_descr_get
- 0, // tp_descr_set
- 0, // tp_dictoffset
- 0, // tp_init
-};
-
-} // namespace python
-} // namespace protobuf
-} // namespace google
diff --git a/python/google/protobuf/pyext/scalar_map_container.cc b/python/google/protobuf/pyext/scalar_map_container.cc
deleted file mode 100644
index 0b0d5a3d..00000000
--- a/python/google/protobuf/pyext/scalar_map_container.cc
+++ /dev/null
@@ -1,542 +0,0 @@
-// Protocol Buffers - Google's data interchange format
-// Copyright 2008 Google Inc. All rights reserved.
-// https://developers.google.com/protocol-buffers/
-//
-// Redistribution and use in source and binary forms, with or without
-// modification, are permitted provided that the following conditions are
-// met:
-//
-// * Redistributions of source code must retain the above copyright
-// notice, this list of conditions and the following disclaimer.
-// * Redistributions in binary form must reproduce the above
-// copyright notice, this list of conditions and the following disclaimer
-// in the documentation and/or other materials provided with the
-// distribution.
-// * Neither the name of Google Inc. nor the names of its
-// contributors may be used to endorse or promote products derived from
-// this software without specific prior written permission.
-//
-// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-// Author: haberman@google.com (Josh Haberman)
-
-#include <google/protobuf/pyext/scalar_map_container.h>
-
-#include <google/protobuf/stubs/logging.h>
-#include <google/protobuf/stubs/common.h>
-#include <google/protobuf/message.h>
-#include <google/protobuf/pyext/message.h>
-#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
-
-namespace google {
-namespace protobuf {
-namespace python {
-
-struct ScalarMapIterator {
- PyObject_HEAD;
-
- // This dict contains the full contents of what we want to iterate over.
- // There's no way to avoid building this, because the list representation
- // (which is canonical) can contain duplicate keys. So at the very least we
- // need a set that lets us skip duplicate keys. And at the point that we're
- // doing that, we might as well just build the actual dict we're iterating
- // over and use dict's built-in iterator.
- PyObject* dict;
-
- // An iterator on dict.
- PyObject* iter;
-
- // A pointer back to the container, so we can notice changes to the version.
- ScalarMapContainer* container;
-
- // The version of the map when we took the iterator to it.
- //
- // We store this so that if the map is modified during iteration we can throw
- // an error.
- uint64 version;
-};
-
-static ScalarMapIterator* GetIter(PyObject* obj) {
- return reinterpret_cast<ScalarMapIterator*>(obj);
-}
-
-namespace scalar_map_container {
-
-static ScalarMapContainer* GetMap(PyObject* obj) {
- return reinterpret_cast<ScalarMapContainer*>(obj);
-}
-
-// The private constructor of ScalarMapContainer objects.
-PyObject *NewContainer(
- CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
- if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
- return NULL;
- }
-
-#if PY_MAJOR_VERSION >= 3
- ScopedPyObjectPtr obj(PyType_GenericAlloc(
- reinterpret_cast<PyTypeObject *>(ScalarMapContainer_Type), 0));
-#else
- ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0));
-#endif
- if (obj.get() == NULL) {
- return PyErr_Format(PyExc_RuntimeError,
- "Could not allocate new container.");
- }
-
- ScalarMapContainer* self = GetMap(obj.get());
-
- self->message = parent->message;
- self->parent = parent;
- self->parent_field_descriptor = parent_field_descriptor;
- self->owner = parent->owner;
- self->version = 0;
-
- self->key_field_descriptor =
- parent_field_descriptor->message_type()->FindFieldByName("key");
- self->value_field_descriptor =
- parent_field_descriptor->message_type()->FindFieldByName("value");
-
- if (self->key_field_descriptor == NULL ||
- self->value_field_descriptor == NULL) {
- return PyErr_Format(PyExc_KeyError,
- "Map entry descriptor did not have key/value fields");
- }
-
- return obj.release();
-}
-
-// Initializes the underlying Message object of "to" so it becomes a new parent
-// repeated scalar, and copies all the values from "from" to it. A child scalar
-// container can be released by passing it as both from and to (e.g. making it
-// the recipient of the new parent message and copying the values from itself).
-static int InitializeAndCopyToParentContainer(
- ScalarMapContainer* from,
- ScalarMapContainer* to) {
- // For now we require from == to, re-evaluate if we want to support deep copy
- // as in repeated_scalar_container.cc.
- GOOGLE_DCHECK(from == to);
- Message* old_message = from->message;
- Message* new_message = old_message->New();
- to->parent = NULL;
- to->parent_field_descriptor = from->parent_field_descriptor;
- to->message = new_message;
- to->owner.reset(new_message);
-
- vector<const FieldDescriptor*> fields;
- fields.push_back(from->parent_field_descriptor);
- old_message->GetReflection()->SwapFields(old_message, new_message, fields);
- return 0;
-}
-
-int Release(ScalarMapContainer* self) {
- return InitializeAndCopyToParentContainer(self, self);
-}
-
-void SetOwner(ScalarMapContainer* self,
- const shared_ptr<Message>& new_owner) {
- self->owner = new_owner;
-}
-
-Py_ssize_t Length(PyObject* _self) {
- ScalarMapContainer* self = GetMap(_self);
- google::protobuf::Message* message = self->message;
- return message->GetReflection()->FieldSize(*message,
- self->parent_field_descriptor);
-}
-
-int MapKeyMatches(ScalarMapContainer* self, const Message* entry,
- PyObject* key) {
- // TODO(haberman): do we need more strict type checking?
- ScopedPyObjectPtr entry_key(
- cmessage::InternalGetScalar(entry, self->key_field_descriptor));
- int ret = PyObject_RichCompareBool(key, entry_key.get(), Py_EQ);
- return ret;
-}
-
-PyObject* GetItem(PyObject* _self, PyObject* key) {
- ScalarMapContainer* self = GetMap(_self);
-
- Message* message = self->message;
- const Reflection* reflection = message->GetReflection();
-
- // Right now the Reflection API doesn't support map lookup, so we implement it
- // via linear search.
- //
- // TODO(haberman): add lookup API to Reflection API.
- size_t size = reflection->FieldSize(*message, self->parent_field_descriptor);
- for (int i = size - 1; i >= 0; i--) {
- const Message& entry = reflection->GetRepeatedMessage(
- *message, self->parent_field_descriptor, i);
- int matches = MapKeyMatches(self, &entry, key);
- if (matches < 0) return NULL;
- if (matches) {
- return cmessage::InternalGetScalar(&entry, self->value_field_descriptor);
- }
- }
-
- // Need to add a new entry.
- Message* entry =
- reflection->AddMessage(message, self->parent_field_descriptor);
- PyObject* ret = NULL;
-
- if (cmessage::InternalSetNonOneofScalar(entry, self->key_field_descriptor,
- key) >= 0) {
- ret = cmessage::InternalGetScalar(entry, self->value_field_descriptor);
- }
-
- self->version++;
-
- // If there was a type error above, it set the Python exception.
- return ret;
-}
-
-int SetItem(PyObject *_self, PyObject *key, PyObject *v) {
- ScalarMapContainer* self = GetMap(_self);
- cmessage::AssureWritable(self->parent);
-
- Message* message = self->message;
- const Reflection* reflection = message->GetReflection();
- size_t size =
- reflection->FieldSize(*message, self->parent_field_descriptor);
- self->version++;
-
- if (v) {
- // Set item.
- //
- // Right now the Reflection API doesn't support map lookup, so we implement
- // it via linear search.
- //
- // TODO(haberman): add lookup API to Reflection API.
- for (int i = size - 1; i >= 0; i--) {
- Message* entry = reflection->MutableRepeatedMessage(
- message, self->parent_field_descriptor, i);
- int matches = MapKeyMatches(self, entry, key);
- if (matches < 0) return -1;
- if (matches) {
- return cmessage::InternalSetNonOneofScalar(
- entry, self->value_field_descriptor, v);
- }
- }
-
- // Key is not already present; insert a new entry.
- Message* entry =
- reflection->AddMessage(message, self->parent_field_descriptor);
-
- if (cmessage::InternalSetNonOneofScalar(entry, self->key_field_descriptor,
- key) < 0 ||
- cmessage::InternalSetNonOneofScalar(entry, self->value_field_descriptor,
- v) < 0) {
- reflection->RemoveLast(message, self->parent_field_descriptor);
- return -1;
- }
-
- return 0;
- } else {
- bool found = false;
- for (int i = size - 1; i >= 0; i--) {
- Message* entry = reflection->MutableRepeatedMessage(
- message, self->parent_field_descriptor, i);
- int matches = MapKeyMatches(self, entry, key);
- if (matches < 0) return -1;
- if (matches) {
- found = true;
- if (i != (int)size - 1) {
- reflection->SwapElements(message, self->parent_field_descriptor, i,
- size - 1);
- }
- reflection->RemoveLast(message, self->parent_field_descriptor);
-
- // Can't exit now, the repeated field representation of maps allows
- // duplicate keys, and we have to be sure to remove all of them.
- }
- }
-
- if (found) {
- return 0;
- } else {
- PyErr_Format(PyExc_KeyError, "Key not present in map");
- return -1;
- }
- }
-}
-
-PyObject* GetIterator(PyObject *_self) {
- ScalarMapContainer* self = GetMap(_self);
-
- ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapIterator_Type, 0));
- if (obj == NULL) {
- return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
- }
-
- ScalarMapIterator* iter = GetIter(obj.get());
-
- Py_INCREF(self);
- iter->container = self;
- iter->version = self->version;
- iter->dict = PyDict_New();
- if (iter->dict == NULL) {
- return PyErr_Format(PyExc_RuntimeError,
- "Could not allocate dict for iterator.");
- }
-
- // Build the entire map into a dict right now. Start from the beginning so
- // that later entries win in the case of duplicates.
- Message* message = self->message;
- const Reflection* reflection = message->GetReflection();
-
- // Right now the Reflection API doesn't support map lookup, so we implement it
- // via linear search. We need to search from the end because the underlying
- // representation can have duplicates if a user calls MergeFrom(); the last
- // one needs to win.
- //
- // TODO(haberman): add lookup API to Reflection API.
- size_t size =
- reflection->FieldSize(*message, self->parent_field_descriptor);
- for (size_t i = 0; i < size; i++) {
- Message* entry = reflection->MutableRepeatedMessage(
- message, self->parent_field_descriptor, i);
- ScopedPyObjectPtr key(
- cmessage::InternalGetScalar(entry, self->key_field_descriptor));
- ScopedPyObjectPtr val(
- cmessage::InternalGetScalar(entry, self->value_field_descriptor));
- if (PyDict_SetItem(iter->dict, key.get(), val.get()) < 0) {
- return PyErr_Format(PyExc_RuntimeError,
- "SetItem failed in iterator construction.");
- }
- }
-
-
- iter->iter = PyObject_GetIter(iter->dict);
-
-
- return obj.release();
-}
-
-PyObject* Clear(PyObject* _self) {
- ScalarMapContainer* self = GetMap(_self);
- cmessage::AssureWritable(self->parent);
- Message* message = self->message;
- const Reflection* reflection = message->GetReflection();
-
- reflection->ClearField(message, self->parent_field_descriptor);
-
- Py_RETURN_NONE;
-}
-
-PyObject* Contains(PyObject* _self, PyObject* key) {
- ScalarMapContainer* self = GetMap(_self);
-
- Message* message = self->message;
- const Reflection* reflection = message->GetReflection();
-
- // Right now the Reflection API doesn't support map lookup, so we implement it
- // via linear search.
- //
- // TODO(haberman): add lookup API to Reflection API.
- size_t size = reflection->FieldSize(*message, self->parent_field_descriptor);
- for (int i = size - 1; i >= 0; i--) {
- const Message& entry = reflection->GetRepeatedMessage(
- *message, self->parent_field_descriptor, i);
- int matches = MapKeyMatches(self, &entry, key);
- if (matches < 0) return NULL;
- if (matches) {
- Py_RETURN_TRUE;
- }
- }
-
- Py_RETURN_FALSE;
-}
-
-PyObject* Get(PyObject* self, PyObject* args) {
- PyObject* key;
- PyObject* default_value = NULL;
- if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
- return NULL;
- }
-
- ScopedPyObjectPtr is_present(Contains(self, key));
- if (is_present.get() == NULL) {
- return NULL;
- }
-
- if (PyObject_IsTrue(is_present.get())) {
- return GetItem(self, key);
- } else {
- if (default_value != NULL) {
- Py_INCREF(default_value);
- return default_value;
- } else {
- Py_RETURN_NONE;
- }
- }
-}
-
-static void Dealloc(PyObject* _self) {
- ScalarMapContainer* self = GetMap(_self);
- self->owner.reset();
- Py_TYPE(_self)->tp_free(_self);
-}
-
-static PyMethodDef Methods[] = {
- { "__contains__", Contains, METH_O,
- "Tests whether a key is a member of the map." },
- { "clear", (PyCFunction)Clear, METH_NOARGS,
- "Removes all elements from the map." },
- { "get", Get, METH_VARARGS,
- "Gets the value for the given key if present, or otherwise a default" },
- /*
- { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
- "Makes a deep copy of the class." },
- { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
- "Outputs picklable representation of the repeated field." },
- */
- {NULL, NULL},
-};
-
-} // namespace scalar_map_container
-
-namespace scalar_map_iterator {
-
-static void Dealloc(PyObject* _self) {
- ScalarMapIterator* self = GetIter(_self);
- Py_DECREF(self->dict);
- Py_DECREF(self->iter);
- Py_DECREF(self->container);
- Py_TYPE(_self)->tp_free(_self);
-}
-
-PyObject* IterNext(PyObject* _self) {
- ScalarMapIterator* self = GetIter(_self);
-
- // This won't catch mutations to the map performed by MergeFrom(); no easy way
- // to address that.
- if (self->version != self->container->version) {
- return PyErr_Format(PyExc_RuntimeError,
- "Map modified during iteration.");
- }
-
- return PyIter_Next(self->iter);
-}
-
-} // namespace scalar_map_iterator
-
-
-#if PY_MAJOR_VERSION >= 3
- static PyType_Slot ScalarMapContainer_Type_slots[] = {
- {Py_tp_dealloc, (void *)scalar_map_container::Dealloc},
- {Py_mp_length, (void *)scalar_map_container::Length},
- {Py_mp_subscript, (void *)scalar_map_container::GetItem},
- {Py_mp_ass_subscript, (void *)scalar_map_container::SetItem},
- {Py_tp_methods, (void *)scalar_map_container::Methods},
- {Py_tp_iter, (void *)scalar_map_container::GetIterator},
- {0, 0},
- };
-
- PyType_Spec ScalarMapContainer_Type_spec = {
- FULL_MODULE_NAME ".ScalarMapContainer",
- sizeof(ScalarMapContainer),
- 0,
- Py_TPFLAGS_DEFAULT,
- ScalarMapContainer_Type_slots
- };
- PyObject *ScalarMapContainer_Type;
-#else
- static PyMappingMethods MpMethods = {
- scalar_map_container::Length, // mp_length
- scalar_map_container::GetItem, // mp_subscript
- scalar_map_container::SetItem, // mp_ass_subscript
- };
-
- PyTypeObject ScalarMapContainer_Type = {
- PyVarObject_HEAD_INIT(&PyType_Type, 0)
- FULL_MODULE_NAME ".ScalarMapContainer", // tp_name
- sizeof(ScalarMapContainer), // tp_basicsize
- 0, // tp_itemsize
- scalar_map_container::Dealloc, // tp_dealloc
- 0, // tp_print
- 0, // tp_getattr
- 0, // tp_setattr
- 0, // tp_compare
- 0, // tp_repr
- 0, // tp_as_number
- 0, // tp_as_sequence
- &MpMethods, // tp_as_mapping
- 0, // tp_hash
- 0, // tp_call
- 0, // tp_str
- 0, // tp_getattro
- 0, // tp_setattro
- 0, // tp_as_buffer
- Py_TPFLAGS_DEFAULT, // tp_flags
- "A scalar map container", // tp_doc
- 0, // tp_traverse
- 0, // tp_clear
- 0, // tp_richcompare
- 0, // tp_weaklistoffset
- scalar_map_container::GetIterator, // tp_iter
- 0, // tp_iternext
- scalar_map_container::Methods, // tp_methods
- 0, // tp_members
- 0, // tp_getset
- 0, // tp_base
- 0, // tp_dict
- 0, // tp_descr_get
- 0, // tp_descr_set
- 0, // tp_dictoffset
- 0, // tp_init
- };
-#endif
-
-PyTypeObject ScalarMapIterator_Type = {
- PyVarObject_HEAD_INIT(&PyType_Type, 0)
- FULL_MODULE_NAME ".ScalarMapIterator", // tp_name
- sizeof(ScalarMapIterator), // tp_basicsize
- 0, // tp_itemsize
- scalar_map_iterator::Dealloc, // tp_dealloc
- 0, // tp_print
- 0, // tp_getattr
- 0, // tp_setattr
- 0, // tp_compare
- 0, // tp_repr
- 0, // tp_as_number
- 0, // tp_as_sequence
- 0, // tp_as_mapping
- 0, // tp_hash
- 0, // tp_call
- 0, // tp_str
- 0, // tp_getattro
- 0, // tp_setattro
- 0, // tp_as_buffer
- Py_TPFLAGS_DEFAULT, // tp_flags
- "A scalar map iterator", // tp_doc
- 0, // tp_traverse
- 0, // tp_clear
- 0, // tp_richcompare
- 0, // tp_weaklistoffset
- PyObject_SelfIter, // tp_iter
- scalar_map_iterator::IterNext, // tp_iternext
- 0, // tp_methods
- 0, // tp_members
- 0, // tp_getset
- 0, // tp_base
- 0, // tp_dict
- 0, // tp_descr_get
- 0, // tp_descr_set
- 0, // tp_dictoffset
- 0, // tp_init
-};
-
-} // namespace python
-} // namespace protobuf
-} // namespace google
diff --git a/python/google/protobuf/pyext/scalar_map_container.h b/python/google/protobuf/pyext/scalar_map_container.h
deleted file mode 100644
index 4d663b88..00000000
--- a/python/google/protobuf/pyext/scalar_map_container.h
+++ /dev/null
@@ -1,119 +0,0 @@
-// Protocol Buffers - Google's data interchange format
-// Copyright 2008 Google Inc. All rights reserved.
-// https://developers.google.com/protocol-buffers/
-//
-// Redistribution and use in source and binary forms, with or without
-// modification, are permitted provided that the following conditions are
-// met:
-//
-// * Redistributions of source code must retain the above copyright
-// notice, this list of conditions and the following disclaimer.
-// * Redistributions in binary form must reproduce the above
-// copyright notice, this list of conditions and the following disclaimer
-// in the documentation and/or other materials provided with the
-// distribution.
-// * Neither the name of Google Inc. nor the names of its
-// contributors may be used to endorse or promote products derived from
-// this software without specific prior written permission.
-//
-// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SCALAR_MAP_CONTAINER_H__
-#define GOOGLE_PROTOBUF_PYTHON_CPP_SCALAR_MAP_CONTAINER_H__
-
-#include <Python.h>
-
-#include <memory>
-#ifndef _SHARED_PTR_H
-#include <google/protobuf/stubs/shared_ptr.h>
-#endif
-
-#include <google/protobuf/descriptor.h>
-
-namespace google {
-namespace protobuf {
-
-class Message;
-
-#ifdef _SHARED_PTR_H
-using std::shared_ptr;
-#else
-using internal::shared_ptr;
-#endif
-
-namespace python {
-
-struct CMessage;
-
-struct ScalarMapContainer {
- PyObject_HEAD;
-
- // This is the top-level C++ Message object that owns the whole
- // proto tree. Every Python ScalarMapContainer holds a
- // reference to it in order to keep it alive as long as there's a
- // Python object that references any part of the tree.
- shared_ptr<Message> owner;
-
- // Pointer to the C++ Message that contains this container. The
- // ScalarMapContainer does not own this pointer.
- Message* message;
-
- // Weak reference to a parent CMessage object (i.e. may be NULL.)
- //
- // Used to make sure all ancestors are also mutable when first
- // modifying the container.
- CMessage* parent;
-
- // Pointer to the parent's descriptor that describes this
- // field. Used together with the parent's message when making a
- // default message instance mutable.
- // The pointer is owned by the global DescriptorPool.
- const FieldDescriptor* parent_field_descriptor;
- const FieldDescriptor* key_field_descriptor;
- const FieldDescriptor* value_field_descriptor;
-
- // We bump this whenever we perform a mutation, to invalidate existing
- // iterators.
- uint64 version;
-};
-
-#if PY_MAJOR_VERSION >= 3
- extern PyObject *ScalarMapContainer_Type;
- extern PyType_Spec ScalarMapContainer_Type_spec;
-#else
- extern PyTypeObject ScalarMapContainer_Type;
-#endif
-extern PyTypeObject ScalarMapIterator_Type;
-
-namespace scalar_map_container {
-
-// Builds a ScalarMapContainer object, from a parent message and a
-// field descriptor.
-extern PyObject *NewContainer(
- CMessage* parent, const FieldDescriptor* parent_field_descriptor);
-
-// Releases the messages in the container to a new message.
-//
-// Returns 0 on success, -1 on failure.
-int Release(ScalarMapContainer* self);
-
-// Set the owner field of self and any children of self.
-void SetOwner(ScalarMapContainer* self,
- const shared_ptr<Message>& new_owner);
-
-} // namespace scalar_map_container
-} // namespace python
-} // namespace protobuf
-
-} // namespace google
-#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SCALAR_MAP_CONTAINER_H__
diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py
index b81ef4d7..87760f26 100644
--- a/python/google/protobuf/symbol_database.py
+++ b/python/google/protobuf/symbol_database.py
@@ -60,7 +60,6 @@ Example usage:
"""
-from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool
@@ -73,37 +72,12 @@ class SymbolDatabase(object):
buffer types used within a program.
"""
- # pylint: disable=protected-access
- if _descriptor._USE_C_DESCRIPTORS:
-
- def __new__(cls):
- raise TypeError("Instances of SymbolDatabase cannot be created")
-
- @classmethod
- def _CreateDefaultDatabase(cls):
- self = object.__new__(cls) # Bypass the __new__ above.
- # Don't call __init__() and initialize here.
- self._symbols = {}
- self._symbols_by_file = {}
- # As of today all descriptors are registered and retrieved from
- # _message.default_pool (see FileDescriptor.__new__), so it's not
- # necessary to use another pool.
- self.pool = _descriptor._message.default_pool
- return self
- # pylint: enable=protected-access
-
- else:
-
- @classmethod
- def _CreateDefaultDatabase(cls):
- return cls()
-
- def __init__(self):
+ def __init__(self, pool=None):
"""Constructor."""
self._symbols = {}
self._symbols_by_file = {}
- self.pool = descriptor_pool.DescriptorPool()
+ self.pool = pool or descriptor_pool.Default()
def RegisterMessage(self, message):
"""Registers the given message type in the local database.
@@ -203,7 +177,7 @@ class SymbolDatabase(object):
result.update(self._symbols_by_file[f])
return result
-_DEFAULT = SymbolDatabase._CreateDefaultDatabase()
+_DEFAULT = SymbolDatabase(pool=descriptor_pool.Default())
def Default():
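[SymbolDatabase now accepts an optional descriptor pool and defaults to the shared one. A short usage sketch:]

    from google.protobuf import descriptor_pool, symbol_database

    db = symbol_database.SymbolDatabase(pool=descriptor_pool.Default())
    shared = symbol_database.Default()   # the module-level default instance
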
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
index e4fadf09..8d256076 100755
--- a/python/google/protobuf/text_format.py
+++ b/python/google/protobuf/text_format.py
@@ -66,6 +66,7 @@ _FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?', re.IGNORECASE)
_FLOAT_NAN = re.compile('nanf?', re.IGNORECASE)
_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT,
descriptor.FieldDescriptor.CPPTYPE_DOUBLE])
+_QUOTES = frozenset(("'", '"'))
class Error(Exception):
@@ -73,7 +74,8 @@ class Error(Exception):
class ParseError(Error):
- """Thrown in case of ASCII parsing error."""
+ """Thrown in case of text parsing error."""
+
class TextWriter(object):
def __init__(self, as_utf8):
@@ -102,7 +104,8 @@ def MessageToString(message, as_utf8=False, as_one_line=False,
Floating point values can be formatted compactly with 15 digits of
precision (which is the most that IEEE 754 "double" can guarantee)
- using float_format='.15g'.
+ using float_format='.15g'. To ensure that converting to text and back to a
+ proto will result in an identical value, float_format='.17g' should be used.
Args:
message: The protocol buffers message.
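The '.17g' recommendation can be checked with a quick round trip. A sketch, where my_pb2.Sample is a hypothetical generated message with a double field named "value":

    from google.protobuf import text_format
    import my_pb2  # hypothetical generated module

    original = my_pb2.Sample(value=2.0 / 3.0)
    text = text_format.MessageToString(original, float_format='.17g')

    restored = my_pb2.Sample()
    text_format.Parse(text, restored)
    # 17 significant digits are enough to recover the exact double.
    assert restored.value == original.value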
@@ -130,11 +133,13 @@ def MessageToString(message, as_utf8=False, as_one_line=False,
return result.rstrip()
return result
+
def _IsMapEntry(field):
return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
field.message_type.has_options and
field.message_type.GetOptions().map_entry)
+
def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False,
pointy_brackets=False, use_index_order=False,
float_format=None):
@@ -166,17 +171,18 @@ def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False,
use_index_order=use_index_order,
float_format=float_format)
+
def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False,
pointy_brackets=False, use_index_order=False, float_format=None):
"""Print a single field name/value pair. For repeated fields, the value
- should be a single element."""
+ should be a single element.
+ """
out.write(' ' * indent)
if field.is_extension:
out.write('[')
if (field.containing_type.GetOptions().message_set_wire_format and
field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
- field.message_type == field.extension_scope and
field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL):
out.write(field.message_type.full_name)
else:
@@ -262,95 +268,113 @@ def PrintFieldValue(field, value, out, indent=0, as_utf8=False,
out.write(str(value))
-def Parse(text, message):
- """Parses an ASCII representation of a protocol message into a message.
+def Parse(text, message, allow_unknown_extension=False):
+ """Parses an text representation of a protocol message into a message.
Args:
- text: Message ASCII representation.
+ text: Message text representation.
message: A protocol buffer message to merge into.
+    allow_unknown_extension: If True, skip over missing extensions and keep
+      parsing.
Returns:
The same message passed as argument.
Raises:
- ParseError: On ASCII parsing problems.
+ ParseError: On text parsing problems.
"""
- if not isinstance(text, str): text = text.decode('utf-8')
- return ParseLines(text.split('\n'), message)
+ if not isinstance(text, str):
+ text = text.decode('utf-8')
+ return ParseLines(text.split('\n'), message, allow_unknown_extension)
-def Merge(text, message):
- """Parses an ASCII representation of a protocol message into a message.
+def Merge(text, message, allow_unknown_extension=False):
+ """Parses an text representation of a protocol message into a message.
Like Parse(), but allows repeated values for a non-repeated field, and uses
the last one.
Args:
- text: Message ASCII representation.
+ text: Message text representation.
message: A protocol buffer message to merge into.
+    allow_unknown_extension: If True, skip over missing extensions and keep
+      parsing.
Returns:
The same message passed as argument.
Raises:
- ParseError: On ASCII parsing problems.
+ ParseError: On text parsing problems.
"""
- return MergeLines(text.split('\n'), message)
+ return MergeLines(text.split('\n'), message, allow_unknown_extension)
-def ParseLines(lines, message):
- """Parses an ASCII representation of a protocol message into a message.
+def ParseLines(lines, message, allow_unknown_extension=False):
+ """Parses an text representation of a protocol message into a message.
Args:
- lines: An iterable of lines of a message's ASCII representation.
+ lines: An iterable of lines of a message's text representation.
message: A protocol buffer message to merge into.
+    allow_unknown_extension: If True, skip over missing extensions and keep
+      parsing.
Returns:
The same message passed as argument.
Raises:
- ParseError: On ASCII parsing problems.
+ ParseError: On text parsing problems.
"""
- _ParseOrMerge(lines, message, False)
+ _ParseOrMerge(lines, message, False, allow_unknown_extension)
return message
-def MergeLines(lines, message):
- """Parses an ASCII representation of a protocol message into a message.
+def MergeLines(lines, message, allow_unknown_extension=False):
+ """Parses an text representation of a protocol message into a message.
Args:
- lines: An iterable of lines of a message's ASCII representation.
+ lines: An iterable of lines of a message's text representation.
message: A protocol buffer message to merge into.
+    allow_unknown_extension: If True, skip over missing extensions and keep
+      parsing.
Returns:
The same message passed as argument.
Raises:
- ParseError: On ASCII parsing problems.
+ ParseError: On text parsing problems.
"""
- _ParseOrMerge(lines, message, True)
+ _ParseOrMerge(lines, message, True, allow_unknown_extension)
return message
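All four public entry points (Parse, Merge, ParseLines, MergeLines) now thread allow_unknown_extension through to the parsing loop below. A usage sketch, where my_pb2.Container is a hypothetical proto2 message with an extension range and an int32 field named known_field:

    from google.protobuf import text_format
    import my_pb2  # hypothetical generated module

    text = '[some.unlinked.extension] { payload: "x" } known_field: 1'
    msg = my_pb2.Container()
    # Without the flag this raises ParseError ('Extension "..." not
    # registered.'); with it, the extension body is skipped and parsing
    # continues with known_field.
    text_format.Parse(text, msg, allow_unknown_extension=True)
    assert msg.known_field == 1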
-def _ParseOrMerge(lines, message, allow_multiple_scalars):
- """Converts an ASCII representation of a protocol message into a message.
+def _ParseOrMerge(lines,
+ message,
+ allow_multiple_scalars,
+ allow_unknown_extension=False):
+ """Converts an text representation of a protocol message into a message.
Args:
- lines: Lines of a message's ASCII representation.
+ lines: Lines of a message's text representation.
message: A protocol buffer message to merge into.
allow_multiple_scalars: Determines if repeated values for a non-repeated
field are permitted, e.g., the string "foo: 1 foo: 2" for a
required/optional field named "foo".
+    allow_unknown_extension: If True, skip over missing extensions and keep
+      parsing.
Raises:
- ParseError: On ASCII parsing problems.
+ ParseError: On text parsing problems.
"""
tokenizer = _Tokenizer(lines)
while not tokenizer.AtEnd():
- _MergeField(tokenizer, message, allow_multiple_scalars)
+ _MergeField(tokenizer, message, allow_multiple_scalars,
+ allow_unknown_extension)
-def _MergeField(tokenizer, message, allow_multiple_scalars):
+def _MergeField(tokenizer,
+ message,
+ allow_multiple_scalars,
+ allow_unknown_extension=False):
"""Merges a single protocol message field into a message.
Args:
@@ -359,9 +383,11 @@ def _MergeField(tokenizer, message, allow_multiple_scalars):
allow_multiple_scalars: Determines if repeated values for a non-repeated
field are permitted, e.g., the string "foo: 1 foo: 2" for a
required/optional field named "foo".
+    allow_unknown_extension: If True, skip over missing extensions and keep
+      parsing.
Raises:
- ParseError: In case of ASCII parsing problems.
+ ParseError: In case of text parsing problems.
"""
message_descriptor = message.DESCRIPTOR
if (hasattr(message_descriptor, 'syntax') and
@@ -383,13 +409,18 @@ def _MergeField(tokenizer, message, allow_multiple_scalars):
field = message.Extensions._FindExtensionByName(name)
# pylint: enable=protected-access
if not field:
- raise tokenizer.ParseErrorPreviousToken(
- 'Extension "%s" not registered.' % name)
+ if allow_unknown_extension:
+ field = None
+ else:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Extension "%s" not registered.' % name)
elif message_descriptor != field.containing_type:
raise tokenizer.ParseErrorPreviousToken(
'Extension "%s" does not extend message type "%s".' % (
name, message_descriptor.full_name))
+
tokenizer.Consume(']')
+
else:
name = tokenizer.ConsumeIdentifier()
field = message_descriptor.fields_by_name.get(name, None)
@@ -411,7 +442,7 @@ def _MergeField(tokenizer, message, allow_multiple_scalars):
'Message type "%s" has no field named "%s".' % (
message_descriptor.full_name, name))
- if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ if field and field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
is_map_entry = _IsMapEntry(field)
tokenizer.TryConsume(':')
@@ -438,7 +469,8 @@ def _MergeField(tokenizer, message, allow_multiple_scalars):
while not tokenizer.TryConsume(end_token):
if tokenizer.AtEnd():
raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token))
- _MergeField(tokenizer, sub_message, allow_multiple_scalars)
+ _MergeField(tokenizer, sub_message, allow_multiple_scalars,
+ allow_unknown_extension)
if is_map_entry:
value_cpptype = field.message_type.fields_by_name['value'].cpp_type
@@ -447,8 +479,63 @@ def _MergeField(tokenizer, message, allow_multiple_scalars):
value.MergeFrom(sub_message.value)
else:
getattr(message, field.name)[sub_message.key] = sub_message.value
+ elif field:
+ tokenizer.Consume(':')
+ if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED and
+ tokenizer.TryConsume('[')):
+ # Short repeated format, e.g. "foo: [1, 2, 3]"
+ while True:
+ _MergeScalarField(tokenizer, message, field, allow_multiple_scalars)
+ if tokenizer.TryConsume(']'):
+ break
+ tokenizer.Consume(',')
+ else:
+ _MergeScalarField(tokenizer, message, field, allow_multiple_scalars)
+ else: # Proto field is unknown.
+ assert allow_unknown_extension
+ _SkipFieldContents(tokenizer)
+
+ # For historical reasons, fields may optionally be separated by commas or
+ # semicolons.
+ if not tokenizer.TryConsume(','):
+ tokenizer.TryConsume(';')
+
+
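Besides the unknown-extension fallback, the elif branch above adds the short list form for repeated scalar fields. A sketch, where my_pb2.Numbers is a hypothetical message declaring "repeated int32 values = 1;":

    from google.protobuf import text_format
    import my_pb2  # hypothetical generated module

    msg = my_pb2.Numbers()
    text_format.Merge('values: [1, 2, 3]', msg)    # short repeated form
    text_format.Merge('values: 4 values: 5', msg)  # classic form still works
    assert list(msg.values) == [1, 2, 3, 4, 5]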
+def _SkipFieldContents(tokenizer):
+ """Skips over contents (value or message) of a field.
+
+ Args:
+ tokenizer: A tokenizer to parse the field name and values.
+ """
+ # Try to guess the type of this field.
+  # If this field is not a message, there should be a ":" between the
+  # field name and the field value, and the value should not start with
+  # "{" or "<", which mark the beginning of a message body.
+  # If there is no ":", or a "{" or "<" follows the ":", the field has
+  # to be a message or the input is ill-formed.
+ if tokenizer.TryConsume(':') and not tokenizer.LookingAt(
+ '{') and not tokenizer.LookingAt('<'):
+ _SkipFieldValue(tokenizer)
+ else:
+ _SkipFieldMessage(tokenizer)
+
+
+def _SkipField(tokenizer):
+ """Skips over a complete field (name and value/message).
+
+ Args:
+ tokenizer: A tokenizer to parse the field name and values.
+ """
+ if tokenizer.TryConsume('['):
+ # Consume extension name.
+ tokenizer.ConsumeIdentifier()
+ while tokenizer.TryConsume('.'):
+ tokenizer.ConsumeIdentifier()
+ tokenizer.Consume(']')
else:
- _MergeScalarField(tokenizer, message, field, allow_multiple_scalars)
+ tokenizer.ConsumeIdentifier()
+
+ _SkipFieldContents(tokenizer)
# For historical reasons, fields may optionally be separated by commas or
# semicolons.
@@ -456,6 +543,48 @@ def _MergeField(tokenizer, message, allow_multiple_scalars):
tokenizer.TryConsume(';')
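The comma/semicolon tolerance applies after skipped fields just as it does after known ones, so text such as the following still parses (my_pb2.Pair is a hypothetical message with int32 fields a and b):

    from google.protobuf import text_format
    import my_pb2  # hypothetical generated module

    msg = my_pb2.Pair()
    text_format.Merge('a: 1, b: 2;', msg)  # separators are tolerated
    assert (msg.a, msg.b) == (1, 2)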
+def _SkipFieldMessage(tokenizer):
+ """Skips over a field message.
+
+ Args:
+ tokenizer: A tokenizer to parse the field name and values.
+ """
+
+ if tokenizer.TryConsume('<'):
+ delimiter = '>'
+ else:
+ tokenizer.Consume('{')
+ delimiter = '}'
+
+ while not tokenizer.LookingAt('>') and not tokenizer.LookingAt('}'):
+ _SkipField(tokenizer)
+
+ tokenizer.Consume(delimiter)
+
+
+def _SkipFieldValue(tokenizer):
+ """Skips over a field value.
+
+ Args:
+ tokenizer: A tokenizer to parse the field name and values.
+
+ Raises:
+ ParseError: In case an invalid field value is found.
+ """
+ # String tokens can come in multiple adjacent string literals.
+ # If we can consume one, consume as many as we can.
+ if tokenizer.TryConsumeString():
+ while tokenizer.TryConsumeString():
+ pass
+ return
+
+ if (not tokenizer.TryConsumeIdentifier() and
+ not tokenizer.TryConsumeInt64() and
+ not tokenizer.TryConsumeUint64() and
+ not tokenizer.TryConsumeFloat()):
+ raise ParseError('Invalid field value: ' + tokenizer.token)
+
+
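_SkipFieldValue mirrors ConsumeByteString below: adjacent string literals belong to one value and concatenate, as in C or Python source. A sketch, where my_pb2.Named is a hypothetical message with "optional string name = 1;":

    from google.protobuf import text_format
    import my_pb2  # hypothetical generated module

    msg = my_pb2.Named()
    # Two adjacent literals, even with different quote marks, form one string.
    text_format.Merge('name: "foo" \'bar\'', msg)
    assert msg.name == 'foobar'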
def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars):
"""Merges a single protocol message scalar field into a message.
@@ -468,10 +597,9 @@ def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars):
required/optional field named "foo".
Raises:
- ParseError: In case of ASCII parsing problems.
+ ParseError: In case of text parsing problems.
RuntimeError: On runtime errors.
"""
- tokenizer.Consume(':')
value = None
if field.type in (descriptor.FieldDescriptor.TYPE_INT32,
@@ -525,7 +653,7 @@ def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars):
class _Tokenizer(object):
- """Protocol buffer ASCII representation tokenizer.
+ """Protocol buffer text representation tokenizer.
This class handles the lower level string parsing by splitting it into
meaningful tokens.
@@ -534,11 +662,13 @@ class _Tokenizer(object):
"""
_WHITESPACE = re.compile('(\\s|(#.*$))+', re.MULTILINE)
- _TOKEN = re.compile(
- '[a-zA-Z_][0-9a-zA-Z_+-]*|' # an identifier
- '[0-9+-][0-9a-zA-Z_.+-]*|' # a number
- '\"([^\"\n\\\\]|\\\\.)*(\"|\\\\?$)|' # a double-quoted string
- '\'([^\'\n\\\\]|\\\\.)*(\'|\\\\?$)') # a single-quoted string
+ _TOKEN = re.compile('|'.join([
+ r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier
+ r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*', # a number
+ ] + [ # quoted str for each quote mark
+ r'{qt}([^{qt}\n\\]|\\.)*({qt}|\\?$)'.format(qt=mark) for mark in _QUOTES
+ ]))
+
_IDENTIFIER = re.compile(r'\w+')
def __init__(self, lines):
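Rebuilding the new pattern standalone makes the change easy to probe; the constants below mirror the ones above, and the number alternative now also accepts floats written with a leading dot:

    import re

    _QUOTES = frozenset(("'", '"'))
    _TOKEN = re.compile('|'.join([
        r'[a-zA-Z_][0-9a-zA-Z_+-]*',             # an identifier
        r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*',  # a number
    ] + [  # a quoted string for each quote mark
        r'{qt}([^{qt}\n\\]|\\.)*({qt}|\\?$)'.format(qt=mark) for mark in _QUOTES
    ]))

    assert _TOKEN.match('.5')        # leading-dot float: new in this pattern
    assert _TOKEN.match('"a\\"b"')   # double-quoted string with an escape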
@@ -555,6 +685,9 @@ class _Tokenizer(object):
self._SkipWhitespace()
self.NextToken()
+ def LookingAt(self, token):
+ return self.token == token
+
def AtEnd(self):
"""Checks the end of the text was reached.
@@ -610,6 +743,13 @@ class _Tokenizer(object):
if not self.TryConsume(token):
raise self._ParseError('Expected "%s".' % token)
+ def TryConsumeIdentifier(self):
+ try:
+ self.ConsumeIdentifier()
+ return True
+ except ParseError:
+ return False
+
def ConsumeIdentifier(self):
"""Consumes protocol message field identifier.
@@ -657,6 +797,13 @@ class _Tokenizer(object):
self.NextToken()
return result
+ def TryConsumeInt64(self):
+ try:
+ self.ConsumeInt64()
+ return True
+ except ParseError:
+ return False
+
def ConsumeInt64(self):
"""Consumes a signed 64bit integer number.
@@ -673,6 +820,13 @@ class _Tokenizer(object):
self.NextToken()
return result
+ def TryConsumeUint64(self):
+ try:
+ self.ConsumeUint64()
+ return True
+ except ParseError:
+ return False
+
def ConsumeUint64(self):
"""Consumes an unsigned 64bit integer number.
@@ -689,6 +843,13 @@ class _Tokenizer(object):
self.NextToken()
return result
+ def TryConsumeFloat(self):
+ try:
+ self.ConsumeFloat()
+ return True
+ except ParseError:
+ return False
+
def ConsumeFloat(self):
"""Consumes an floating point number.
@@ -721,6 +882,13 @@ class _Tokenizer(object):
self.NextToken()
return result
+ def TryConsumeString(self):
+ try:
+ self.ConsumeString()
+ return True
+ except ParseError:
+ return False
+
def ConsumeString(self):
"""Consumes a string value.
@@ -746,7 +914,7 @@ class _Tokenizer(object):
ParseError: If a byte array value couldn't be consumed.
"""
the_list = [self._ConsumeSingleByteString()]
- while self.token and self.token[0] in ('\'', '"'):
+ while self.token and self.token[0] in _QUOTES:
the_list.append(self._ConsumeSingleByteString())
return b''.join(the_list)
@@ -757,11 +925,13 @@ class _Tokenizer(object):
tokens which are automatically concatenated, like in C or Python. This
method only consumes one token.
+ Returns:
+ The token parsed.
Raises:
ParseError: When the wrong format data is found.
"""
text = self.token
- if len(text) < 1 or text[0] not in ('\'', '"'):
+ if len(text) < 1 or text[0] not in _QUOTES:
raise self._ParseError('Expected string but found: %r' % (text,))
if len(text) < 2 or text[-1] != text[0]:
diff --git a/python/setup.py b/python/setup.py
index 18865e03..22f6e816 100755
--- a/python/setup.py
+++ b/python/setup.py
@@ -89,6 +89,7 @@ def GenerateUnittestProtos():
generate_proto("../src/google/protobuf/unittest_no_generic_services.proto", False)
generate_proto("../src/google/protobuf/unittest_proto3_arena.proto", False)
generate_proto("../src/google/protobuf/util/json_format_proto3.proto", False)
+ generate_proto("google/protobuf/internal/any_test.proto", False)
generate_proto("google/protobuf/internal/descriptor_pool_test1.proto", False)
generate_proto("google/protobuf/internal/descriptor_pool_test2.proto", False)
generate_proto("google/protobuf/internal/factory_test1.proto", False)