aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google
diff options
context:
space:
mode:
authorGravatar Adam Cozzette <acozzette@google.com>2016-06-29 15:23:27 -0700
committerGravatar Adam Cozzette <acozzette@google.com>2016-06-29 15:38:03 -0700
commitd64a2d9941c36a7bc2a7959ea10ab8363192ac14 (patch)
tree52330d146ad63d3d70f3baade00d5d1fea8f5e0c /python/google
parentc18aa7795a2e02ef700ff8b039d94ecdcc33432f (diff)
Integrated internal changes from Google
This includes all internal changes from around May 20 to now.
Diffstat (limited to 'python/google')
-rwxr-xr-xpython/google/protobuf/descriptor.py42
-rw-r--r--python/google/protobuf/descriptor_pool.py65
-rwxr-xr-xpython/google/protobuf/internal/containers.py6
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test.py18
-rwxr-xr-xpython/google/protobuf/internal/descriptor_test.py41
-rw-r--r--python/google/protobuf/internal/file_options_test.proto43
-rw-r--r--python/google/protobuf/internal/json_format_test.py13
-rwxr-xr-xpython/google/protobuf/internal/message_test.py4
-rwxr-xr-xpython/google/protobuf/internal/text_format_test.py624
-rw-r--r--python/google/protobuf/json_format.py762
-rw-r--r--python/google/protobuf/pyext/descriptor.cc277
-rw-r--r--python/google/protobuf/pyext/descriptor.h6
-rw-r--r--python/google/protobuf/pyext/descriptor_containers.cc158
-rw-r--r--python/google/protobuf/pyext/descriptor_containers.h8
-rw-r--r--python/google/protobuf/pyext/descriptor_pool.cc38
-rw-r--r--python/google/protobuf/pyext/map_container.cc1
-rw-r--r--python/google/protobuf/pyext/message.cc66
-rw-r--r--python/google/protobuf/pyext/message.h6
-rw-r--r--python/google/protobuf/pyext/message_module.cc88
-rwxr-xr-xpython/google/protobuf/text_format.py613
20 files changed, 2007 insertions, 872 deletions
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py
index 3209b34d..2eba1232 100755
--- a/python/google/protobuf/descriptor.py
+++ b/python/google/protobuf/descriptor.py
@@ -258,7 +258,7 @@ class Descriptor(_NestedDescriptorBase):
def __new__(cls, name, full_name, filename, containing_type, fields,
nested_types, enum_types, extensions, options=None,
is_extendable=True, extension_ranges=None, oneofs=None,
- file=None, serialized_start=None, serialized_end=None,
+ file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin
syntax=None):
_message.Message._CheckCalledFromGeneratedFile()
return _message.default_pool.FindMessageTypeByName(full_name)
@@ -269,8 +269,8 @@ class Descriptor(_NestedDescriptorBase):
def __init__(self, name, full_name, filename, containing_type, fields,
nested_types, enum_types, extensions, options=None,
is_extendable=True, extension_ranges=None, oneofs=None,
- file=None, serialized_start=None, serialized_end=None,
- syntax=None): # pylint:disable=redefined-builtin
+ file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin
+ syntax=None):
"""Arguments to __init__() are as described in the description
of Descriptor fields above.
@@ -665,7 +665,7 @@ class EnumValueDescriptor(DescriptorBase):
self.type = type
-class OneofDescriptor(object):
+class OneofDescriptor(DescriptorBase):
"""Descriptor for a oneof field.
name: (str) Name of the oneof field.
@@ -682,12 +682,15 @@ class OneofDescriptor(object):
if _USE_C_DESCRIPTORS:
_C_DESCRIPTOR_CLASS = _message.OneofDescriptor
- def __new__(cls, name, full_name, index, containing_type, fields):
+ def __new__(
+ cls, name, full_name, index, containing_type, fields, options=None):
_message.Message._CheckCalledFromGeneratedFile()
return _message.default_pool.FindOneofByName(full_name)
- def __init__(self, name, full_name, index, containing_type, fields):
+ def __init__(
+ self, name, full_name, index, containing_type, fields, options=None):
"""Arguments are as described in the attribute description above."""
+ super(OneofDescriptor, self).__init__(options, 'OneofOptions')
self.name = name
self.full_name = full_name
self.index = index
@@ -705,11 +708,22 @@ class ServiceDescriptor(_NestedDescriptorBase):
definition appears withing the .proto file.
methods: (list of MethodDescriptor) List of methods provided by this
service.
+ methods_by_name: (dict str -> MethodDescriptor) Same MethodDescriptor
+ objects as in |methods_by_name|, but indexed by "name" attribute in each
+ MethodDescriptor.
options: (descriptor_pb2.ServiceOptions) Service options message or
None to use default service options.
file: (FileDescriptor) Reference to file info.
"""
+ if _USE_C_DESCRIPTORS:
+ _C_DESCRIPTOR_CLASS = _message.ServiceDescriptor
+
+ def __new__(cls, name, full_name, index, methods, options=None, file=None, # pylint: disable=redefined-builtin
+ serialized_start=None, serialized_end=None):
+ _message.Message._CheckCalledFromGeneratedFile() # pylint: disable=protected-access
+ return _message.default_pool.FindServiceByName(full_name)
+
def __init__(self, name, full_name, index, methods, options=None, file=None,
serialized_start=None, serialized_end=None):
super(ServiceDescriptor, self).__init__(
@@ -718,16 +732,14 @@ class ServiceDescriptor(_NestedDescriptorBase):
serialized_end=serialized_end)
self.index = index
self.methods = methods
+ self.methods_by_name = dict((m.name, m) for m in methods)
# Set the containing service for each method in this service.
for method in self.methods:
method.containing_service = self
def FindMethodByName(self, name):
"""Searches for the specified method, and returns its descriptor."""
- for method in self.methods:
- if name == method.name:
- return method
- return None
+ return self.methods_by_name.get(name, None)
def CopyToProto(self, proto):
"""Copies this to a descriptor_pb2.ServiceDescriptorProto.
@@ -754,6 +766,14 @@ class MethodDescriptor(DescriptorBase):
None to use default method options.
"""
+ if _USE_C_DESCRIPTORS:
+ _C_DESCRIPTOR_CLASS = _message.MethodDescriptor
+
+ def __new__(cls, name, full_name, index, containing_service,
+ input_type, output_type, options=None):
+ _message.Message._CheckCalledFromGeneratedFile() # pylint: disable=protected-access
+ return _message.default_pool.FindMethodByName(full_name)
+
def __init__(self, name, full_name, index, containing_service,
input_type, output_type, options=None):
"""The arguments are as described in the description of MethodDescriptor
@@ -788,6 +808,7 @@ class FileDescriptor(DescriptorBase):
message_types_by_name: Dict of message names of their descriptors.
enum_types_by_name: Dict of enum names and their descriptors.
extensions_by_name: Dict of extension names and their descriptors.
+ services_by_name: Dict of services names and their descriptors.
pool: the DescriptorPool this descriptor belongs to. When not passed to the
constructor, the global default pool is used.
"""
@@ -825,6 +846,7 @@ class FileDescriptor(DescriptorBase):
self.enum_types_by_name = {}
self.extensions_by_name = {}
+ self.services_by_name = {}
self.dependencies = (dependencies or [])
self.public_dependencies = (public_dependencies or [])
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index 20a33701..5c055ab9 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -394,6 +394,11 @@ class DescriptorPool(object):
desc_proto_prefix, desc_proto.name, scope)
file_descriptor.message_types_by_name[desc_proto.name] = desc
+ for index, service_proto in enumerate(file_proto.service):
+ file_descriptor.services_by_name[service_proto.name] = (
+ self._MakeServiceDescriptor(service_proto, index, scope,
+ file_proto.package, file_descriptor))
+
self.Add(file_proto)
self._file_descriptors[file_proto.name] = file_descriptor
@@ -441,7 +446,7 @@ class DescriptorPool(object):
for index, extension in enumerate(desc_proto.extension)]
oneofs = [
descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)),
- index, None, [])
+ index, None, [], desc.options)
for index, desc in enumerate(desc_proto.oneof_decl)]
extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
if extension_ranges:
@@ -679,6 +684,64 @@ class DescriptorPool(object):
options=value_proto.options,
type=None)
+ def _MakeServiceDescriptor(self, service_proto, service_index, scope,
+ package, file_desc):
+ """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto.
+
+ Args:
+ service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
+ service_index: The index of the service in the File.
+ scope: Dict mapping short and full symbols to message and enum types.
+ package: Optional package name for the new message EnumDescriptor.
+ file_desc: The file containing the service descriptor.
+
+ Returns:
+ The added descriptor.
+ """
+
+ if package:
+ service_name = '.'.join((package, service_proto.name))
+ else:
+ service_name = service_proto.name
+
+ methods = [self._MakeMethodDescriptor(method_proto, service_name, package,
+ scope, index)
+ for index, method_proto in enumerate(service_proto.method)]
+ desc = descriptor.ServiceDescriptor(name=service_proto.name,
+ full_name=service_name,
+ index=service_index,
+ methods=methods,
+ options=service_proto.options,
+ file=file_desc)
+ return desc
+
+ def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
+ index):
+ """Creates a method descriptor from a MethodDescriptorProto.
+
+ Args:
+ method_proto: The proto describing the method.
+ service_name: The name of the containing service.
+ package: Optional package name to look up for types.
+ scope: Scope containing available types.
+ index: Index of the method in the service.
+
+ Returns:
+ An initialized MethodDescriptor object.
+ """
+ full_name = '.'.join((service_name, method_proto.name))
+ input_type = self._GetTypeFromScope(
+ package, method_proto.input_type, scope)
+ output_type = self._GetTypeFromScope(
+ package, method_proto.output_type, scope)
+ return descriptor.MethodDescriptor(name=method_proto.name,
+ full_name=full_name,
+ index=index,
+ containing_service=None,
+ input_type=input_type,
+ output_type=output_type,
+ options=method_proto.options)
+
def _ExtractSymbols(self, descriptors):
"""Pulls out all the symbols from descriptor protos.
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index 97cdd848..ce46d08c 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -594,7 +594,11 @@ class MessageMap(MutableMapping):
def MergeFrom(self, other):
for key in other:
- self[key].MergeFrom(other[key])
+ # According to documentation: "When parsing from the wire or when merging,
+ # if there are duplicate map keys the last key seen is used".
+ if key in self:
+ del self[key]
+ self[key].CopyFrom(other[key])
# self._message_listener.Modified() not required here, because
# mutations to submessages already propagate.
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index 6a13e0bc..3c8c7935 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -51,6 +51,7 @@ from google.protobuf.internal import descriptor_pool_test1_pb2
from google.protobuf.internal import descriptor_pool_test2_pb2
from google.protobuf.internal import factory_test1_pb2
from google.protobuf.internal import factory_test2_pb2
+from google.protobuf.internal import file_options_test_pb2
from google.protobuf.internal import more_messages_pb2
from google.protobuf import descriptor
from google.protobuf import descriptor_database
@@ -630,6 +631,23 @@ class AddDescriptorTest(unittest.TestCase):
self.assertEqual(pool.FindMessageTypeByName('package.Message').name,
'Message')
+ def testFileDescriptorOptionsWithCustomDescriptorPool(self):
+ # Create a descriptor pool, and add a new FileDescriptorProto to it.
+ pool = descriptor_pool.DescriptorPool()
+ file_name = 'file_descriptor_options_with_custom_descriptor_pool.proto'
+ file_descriptor_proto = descriptor_pb2.FileDescriptorProto(name=file_name)
+ extension_id = file_options_test_pb2.foo_options
+ file_descriptor_proto.options.Extensions[extension_id].foo_name = 'foo'
+ pool.Add(file_descriptor_proto)
+ # The options set on the FileDescriptorProto should be available in the
+ # descriptor even if they contain extensions that cannot be deserialized
+ # using the pool.
+ file_descriptor = pool.FindFileByName(file_name)
+ options = file_descriptor.GetOptions()
+ self.assertEqual('foo', options.Extensions[extension_id].foo_name)
+ # The object returned by GetOptions() is cached.
+ self.assertIs(options, file_descriptor.GetOptions())
+
@unittest.skipIf(
api_implementation.Type() != 'cpp',
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index b8e75553..623198c8 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -77,27 +77,24 @@ class DescriptorTest(unittest.TestCase):
enum_proto.value.add(name='FOREIGN_BAR', number=5)
enum_proto.value.add(name='FOREIGN_BAZ', number=6)
+ file_proto.message_type.add(name='ResponseMessage')
+ service_proto = file_proto.service.add(
+ name='Service')
+ method_proto = service_proto.method.add(
+ name='CallMethod',
+ input_type='.protobuf_unittest.NestedMessage',
+ output_type='.protobuf_unittest.ResponseMessage')
+
+ # Note: Calling DescriptorPool.Add() multiple times with the same file only
+ # works if the input is canonical; in particular, all type names must be
+ # fully qualified.
self.pool = self.GetDescriptorPool()
self.pool.Add(file_proto)
self.my_file = self.pool.FindFileByName(file_proto.name)
self.my_message = self.my_file.message_types_by_name[message_proto.name]
self.my_enum = self.my_message.enum_types_by_name[enum_proto.name]
-
- self.my_method = descriptor.MethodDescriptor(
- name='Bar',
- full_name='protobuf_unittest.TestService.Bar',
- index=0,
- containing_service=None,
- input_type=None,
- output_type=None)
- self.my_service = descriptor.ServiceDescriptor(
- name='TestServiceWithOptions',
- full_name='protobuf_unittest.TestServiceWithOptions',
- file=self.my_file,
- index=0,
- methods=[
- self.my_method
- ])
+ self.my_service = self.my_file.services_by_name[service_proto.name]
+ self.my_method = self.my_service.methods_by_name[method_proto.name]
def GetDescriptorPool(self):
return symbol_database.Default().pool
@@ -139,13 +136,14 @@ class DescriptorTest(unittest.TestCase):
file_descriptor = unittest_custom_options_pb2.DESCRIPTOR
message_descriptor =\
unittest_custom_options_pb2.TestMessageWithCustomOptions.DESCRIPTOR
- field_descriptor = message_descriptor.fields_by_name["field1"]
- enum_descriptor = message_descriptor.enum_types_by_name["AnEnum"]
+ field_descriptor = message_descriptor.fields_by_name['field1']
+ oneof_descriptor = message_descriptor.oneofs_by_name['AnOneof']
+ enum_descriptor = message_descriptor.enum_types_by_name['AnEnum']
enum_value_descriptor =\
- message_descriptor.enum_values_by_name["ANENUM_VAL2"]
+ message_descriptor.enum_values_by_name['ANENUM_VAL2']
service_descriptor =\
unittest_custom_options_pb2.TestServiceWithCustomOptions.DESCRIPTOR
- method_descriptor = service_descriptor.FindMethodByName("Foo")
+ method_descriptor = service_descriptor.FindMethodByName('Foo')
file_options = file_descriptor.GetOptions()
file_opt1 = unittest_custom_options_pb2.file_opt1
@@ -158,6 +156,9 @@ class DescriptorTest(unittest.TestCase):
self.assertEqual(8765432109, field_options.Extensions[field_opt1])
field_opt2 = unittest_custom_options_pb2.field_opt2
self.assertEqual(42, field_options.Extensions[field_opt2])
+ oneof_options = oneof_descriptor.GetOptions()
+ oneof_opt1 = unittest_custom_options_pb2.oneof_opt1
+ self.assertEqual(-99, oneof_options.Extensions[oneof_opt1])
enum_options = enum_descriptor.GetOptions()
enum_opt1 = unittest_custom_options_pb2.enum_opt1
self.assertEqual(-789, enum_options.Extensions[enum_opt1])
diff --git a/python/google/protobuf/internal/file_options_test.proto b/python/google/protobuf/internal/file_options_test.proto
new file mode 100644
index 00000000..4eceeb07
--- /dev/null
+++ b/python/google/protobuf/internal/file_options_test.proto
@@ -0,0 +1,43 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+syntax = "proto2";
+
+import "google/protobuf/descriptor.proto";
+
+package google.protobuf.python.internal;
+
+message FooOptions {
+ optional string foo_name = 1;
+}
+
+extend .google.protobuf.FileOptions {
+ optional FooOptions foo_options = 120436268;
+}
diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py
index 9e32ea47..6df12bea 100644
--- a/python/google/protobuf/internal/json_format_test.py
+++ b/python/google/protobuf/internal/json_format_test.py
@@ -643,6 +643,19 @@ class JsonFormatTest(JsonFormatBase):
'Message type "proto3.TestMessage" has no field named '
'"unknownName".')
+ def testIgnoreUnknownField(self):
+ text = '{"unknownName": 1}'
+ parsed_message = json_format_proto3_pb2.TestMessage()
+ json_format.Parse(text, parsed_message, ignore_unknown_fields=True)
+ text = ('{\n'
+ ' "repeatedValue": [ {\n'
+ ' "@type": "type.googleapis.com/proto3.MessageType",\n'
+ ' "unknownName": 1\n'
+ ' }]\n'
+ '}\n')
+ parsed_message = json_format_proto3_pb2.TestAny()
+ json_format.Parse(text, parsed_message, ignore_unknown_fields=True)
+
def testDuplicateField(self):
# Duplicate key check is not supported for python2.6
if sys.version_info < (2, 7):
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 4ee31d8e..1e95adf9 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -1435,6 +1435,8 @@ class Proto3Test(unittest.TestCase):
msg2.map_int32_int32[12] = 55
msg2.map_int64_int64[88] = 99
msg2.map_int32_foreign_message[222].c = 15
+ msg2.map_int32_foreign_message[222].d = 20
+ old_map_value = msg2.map_int32_foreign_message[222]
msg2.MergeFrom(msg)
@@ -1444,6 +1446,8 @@ class Proto3Test(unittest.TestCase):
self.assertEqual(99, msg2.map_int64_int64[88])
self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
+ self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d'))
+ self.assertEqual(15, old_map_value.c)
# Verify that there is only one entry per key, even though the MergeFrom
# may have internally created multiple entries for a single key in the
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index ab2bf05b..0e38e0e9 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -40,12 +40,13 @@ import six
import string
try:
- import unittest2 as unittest #PY26
+ import unittest2 as unittest # PY26, pylint: disable=g-import-not-at-top
except ImportError:
- import unittest
+ import unittest # pylint: disable=g-import-not-at-top
from google.protobuf.internal import _parameterized
+from google.protobuf import any_test_pb2
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
@@ -53,6 +54,7 @@ from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import test_util
from google.protobuf.internal import message_set_extensions_pb2
+from google.protobuf import descriptor_pool
from google.protobuf import text_format
@@ -90,13 +92,11 @@ class TextFormatBase(unittest.TestCase):
.replace('e-0','e-').replace('e-0','e-')
# Floating point fields are printed with .0 suffix even if they are
# actualy integer numbers.
- text = re.compile('\.0$', re.MULTILINE).sub('', text)
+ text = re.compile(r'\.0$', re.MULTILINE).sub('', text)
return text
-@_parameterized.Parameters(
- (unittest_pb2),
- (unittest_proto3_arena_pb2))
+@_parameterized.Parameters((unittest_pb2), (unittest_proto3_arena_pb2))
class TextFormatTest(TextFormatBase):
def testPrintExotic(self, message_module):
@@ -120,8 +120,10 @@ class TextFormatTest(TextFormatBase):
'repeated_string: "\\303\\274\\352\\234\\237"\n')
def testPrintExoticUnicodeSubclass(self, message_module):
+
class UnicodeSub(six.text_type):
pass
+
message = message_module.TestAllTypes()
message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f'))
self.CompareToGoldenText(
@@ -165,8 +167,8 @@ class TextFormatTest(TextFormatBase):
message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"')
message.repeated_string.append(u'\u00fc\ua71f')
self.CompareToGoldenText(
- self.RemoveRedundantZeros(
- text_format.MessageToString(message, as_one_line=True)),
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message, as_one_line=True)),
'repeated_int64: -9223372036854775808'
' repeated_uint64: 18446744073709551615'
' repeated_double: 123.456'
@@ -187,21 +189,23 @@ class TextFormatTest(TextFormatBase):
message.repeated_string.append(u'\u00fc\ua71f')
# Test as_utf8 = False.
- wire_text = text_format.MessageToString(
- message, as_one_line=True, as_utf8=False)
+ wire_text = text_format.MessageToString(message,
+ as_one_line=True,
+ as_utf8=False)
parsed_message = message_module.TestAllTypes()
r = text_format.Parse(wire_text, parsed_message)
self.assertIs(r, parsed_message)
self.assertEqual(message, parsed_message)
# Test as_utf8 = True.
- wire_text = text_format.MessageToString(
- message, as_one_line=True, as_utf8=True)
+ wire_text = text_format.MessageToString(message,
+ as_one_line=True,
+ as_utf8=True)
parsed_message = message_module.TestAllTypes()
r = text_format.Parse(wire_text, parsed_message)
self.assertIs(r, parsed_message)
self.assertEqual(message, parsed_message,
- '\n%s != %s' % (message, parsed_message))
+ '\n%s != %s' % (message, parsed_message))
def testPrintRawUtf8String(self, message_module):
message = message_module.TestAllTypes()
@@ -211,7 +215,7 @@ class TextFormatTest(TextFormatBase):
parsed_message = message_module.TestAllTypes()
text_format.Parse(text, parsed_message)
self.assertEqual(message, parsed_message,
- '\n%s != %s' % (message, parsed_message))
+ '\n%s != %s' % (message, parsed_message))
def testPrintFloatFormat(self, message_module):
# Check that float_format argument is passed to sub-message formatting.
@@ -232,14 +236,15 @@ class TextFormatTest(TextFormatBase):
message.payload.repeated_double.append(.000078900)
formatted_fields = ['optional_float: 1.25',
'optional_double: -3.45678901234568e-6',
- 'repeated_float: -5642',
- 'repeated_double: 7.89e-5']
+ 'repeated_float: -5642', 'repeated_double: 7.89e-5']
text_message = text_format.MessageToString(message, float_format='.15g')
self.CompareToGoldenText(
self.RemoveRedundantZeros(text_message),
- 'payload {{\n {0}\n {1}\n {2}\n {3}\n}}\n'.format(*formatted_fields))
+ 'payload {{\n {0}\n {1}\n {2}\n {3}\n}}\n'.format(
+ *formatted_fields))
# as_one_line=True is a separate code branch where float_format is passed.
- text_message = text_format.MessageToString(message, as_one_line=True,
+ text_message = text_format.MessageToString(message,
+ as_one_line=True,
float_format='.15g')
self.CompareToGoldenText(
self.RemoveRedundantZeros(text_message),
@@ -311,8 +316,7 @@ class TextFormatTest(TextFormatBase):
self.assertEqual(123.456, message.repeated_double[0])
self.assertEqual(1.23e22, message.repeated_double[1])
self.assertEqual(1.23e-18, message.repeated_double[2])
- self.assertEqual(
- '\000\001\a\b\f\n\r\t\v\\\'"', message.repeated_string[0])
+ self.assertEqual('\000\001\a\b\f\n\r\t\v\\\'"', message.repeated_string[0])
self.assertEqual('foocorgegrault', message.repeated_string[1])
self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2])
self.assertEqual(u'\u00fc', message.repeated_string[3])
@@ -371,45 +375,38 @@ class TextFormatTest(TextFormatBase):
def testParseSingleWord(self, message_module):
message = message_module.TestAllTypes()
text = 'foo'
- six.assertRaisesRegex(self,
- text_format.ParseError,
- (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
- r'"foo".'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"foo".'), text_format.Parse, text, message)
def testParseUnknownField(self, message_module):
message = message_module.TestAllTypes()
text = 'unknown_field: 8\n'
- six.assertRaisesRegex(self,
- text_format.ParseError,
- (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
- r'"unknown_field".'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"unknown_field".'), text_format.Parse, text, message)
def testParseBadEnumValue(self, message_module):
message = message_module.TestAllTypes()
text = 'optional_nested_enum: BARR'
- six.assertRaisesRegex(self,
- text_format.ParseError,
- (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
- r'has no value named BARR.'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError,
+ (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
+ r'has no value named BARR.'), text_format.Parse,
+ text, message)
message = message_module.TestAllTypes()
text = 'optional_nested_enum: 100'
- six.assertRaisesRegex(self,
- text_format.ParseError,
- (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
- r'has no value with number 100.'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError,
+ (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
+ r'has no value with number 100.'), text_format.Parse,
+ text, message)
def testParseBadIntValue(self, message_module):
message = message_module.TestAllTypes()
text = 'optional_int32: bork'
- six.assertRaisesRegex(self,
- text_format.ParseError,
- ('1:17 : Couldn\'t parse integer: bork'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError,
+ ('1:17 : Couldn\'t parse integer: bork'),
+ text_format.Parse, text, message)
def testParseStringFieldUnescape(self, message_module):
message = message_module.TestAllTypes()
@@ -419,6 +416,7 @@ class TextFormatTest(TextFormatBase):
repeated_string: "\\\\xf\\\\x62"
repeated_string: "\\\\\xf\\\\\x62"
repeated_string: "\x5cx20"'''
+
text_format.Parse(text, message)
SLASH = '\\'
@@ -433,8 +431,7 @@ class TextFormatTest(TextFormatBase):
def testMergeDuplicateScalars(self, message_module):
message = message_module.TestAllTypes()
- text = ('optional_int32: 42 '
- 'optional_int32: 67')
+ text = ('optional_int32: 42 ' 'optional_int32: 67')
r = text_format.Merge(text, message)
self.assertIs(r, message)
self.assertEqual(67, message.optional_int32)
@@ -455,13 +452,11 @@ class TextFormatTest(TextFormatBase):
self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
def testParseMultipleOneof(self, message_module):
- m_string = '\n'.join([
- 'oneof_uint32: 11',
- 'oneof_string: "foo"'])
+ m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
m2 = message_module.TestAllTypes()
if message_module is unittest_pb2:
- with self.assertRaisesRegexp(
- text_format.ParseError, ' is specified along with field '):
+ with self.assertRaisesRegexp(text_format.ParseError,
+ ' is specified along with field '):
text_format.Parse(m_string, m2)
else:
text_format.Parse(m_string, m2)
@@ -477,8 +472,8 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
message = unittest_pb2.TestAllTypes()
test_util.SetAllFields(message)
self.CompareToGoldenFile(
- self.RemoveRedundantZeros(
- text_format.MessageToString(message, pointy_brackets=True)),
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message, pointy_brackets=True)),
'text_format_unittest_data_pointy_oneof.txt')
def testParseGolden(self):
@@ -499,14 +494,6 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
self.RemoveRedundantZeros(text_format.MessageToString(message)),
'text_format_unittest_data_oneof_implemented.txt')
- def testPrintAllFieldsPointy(self):
- message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(
- text_format.MessageToString(message, pointy_brackets=True)),
- 'text_format_unittest_data_pointy_oneof.txt')
-
def testPrintInIndexOrder(self):
message = unittest_pb2.TestFieldOrderings()
message.my_string = '115'
@@ -520,8 +507,7 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n'
'optional_nested_message {\n oo: 0\n bb: 1\n}\n')
self.CompareToGoldenText(
- self.RemoveRedundantZeros(text_format.MessageToString(
- message)),
+ self.RemoveRedundantZeros(text_format.MessageToString(message)),
'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n'
'optional_nested_message {\n bb: 1\n oo: 0\n}\n')
@@ -552,14 +538,13 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
message.map_int64_int64[-2**33] = -2**34
message.map_uint32_uint32[123] = 456
message.map_uint64_uint64[2**33] = 2**34
- message.map_string_string["abc"] = "123"
+ message.map_string_string['abc'] = '123'
message.map_int32_foreign_message[111].c = 5
# Maps are serialized to text format using their underlying repeated
# representation.
self.CompareToGoldenText(
- text_format.MessageToString(message),
- 'map_int32_int32 {\n'
+ text_format.MessageToString(message), 'map_int32_int32 {\n'
' key: -123\n'
' value: -456\n'
'}\n'
@@ -592,9 +577,8 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
message.map_string_string[letter] = 'dummy'
for letter in reversed(string.ascii_uppercase[0:13]):
message.map_string_string[letter] = 'dummy'
- golden = ''.join((
- 'map_string_string {\n key: "%c"\n value: "dummy"\n}\n' % (letter,)
- for letter in string.ascii_uppercase))
+ golden = ''.join(('map_string_string {\n key: "%c"\n value: "dummy"\n}\n'
+ % (letter,) for letter in string.ascii_uppercase))
self.CompareToGoldenText(text_format.MessageToString(message), golden)
def testMapOrderSemantics(self):
@@ -602,9 +586,7 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
# The C++ implementation emits defaulted-value fields, while the Python
# implementation does not. Adjusting for this is awkward, but it is
# valuable to test against a common golden file.
- line_blacklist = (' key: 0\n',
- ' value: 0\n',
- ' key: false\n',
+ line_blacklist = (' key: 0\n', ' value: 0\n', ' key: false\n',
' value: false\n')
golden_lines = [line for line in golden_lines if line not in line_blacklist]
@@ -627,8 +609,7 @@ class Proto2Tests(TextFormatBase):
message.message_set.Extensions[ext1].i = 23
message.message_set.Extensions[ext2].str = 'foo'
self.CompareToGoldenText(
- text_format.MessageToString(message),
- 'message_set {\n'
+ text_format.MessageToString(message), 'message_set {\n'
' [protobuf_unittest.TestMessageSetExtension1] {\n'
' i: 23\n'
' }\n'
@@ -654,16 +635,14 @@ class Proto2Tests(TextFormatBase):
message.message_set.Extensions[ext1].i = 23
message.message_set.Extensions[ext2].str = 'foo'
text_format.PrintMessage(message, out, use_field_number=True)
- self.CompareToGoldenText(
- out.getvalue(),
- '1 {\n'
- ' 1545008 {\n'
- ' 15: 23\n'
- ' }\n'
- ' 1547769 {\n'
- ' 25: \"foo\"\n'
- ' }\n'
- '}\n')
+ self.CompareToGoldenText(out.getvalue(), '1 {\n'
+ ' 1545008 {\n'
+ ' 15: 23\n'
+ ' }\n'
+ ' 1547769 {\n'
+ ' 25: \"foo\"\n'
+ ' }\n'
+ '}\n')
out.close()
def testPrintMessageSetAsOneLine(self):
@@ -685,8 +664,7 @@ class Proto2Tests(TextFormatBase):
def testParseMessageSet(self):
message = unittest_pb2.TestAllTypes()
- text = ('repeated_uint64: 1\n'
- 'repeated_uint64: 2\n')
+ text = ('repeated_uint64: 1\n' 'repeated_uint64: 2\n')
text_format.Parse(text, message)
self.assertEqual(1, message.repeated_uint64[0])
self.assertEqual(2, message.repeated_uint64[1])
@@ -708,8 +686,7 @@ class Proto2Tests(TextFormatBase):
def testParseMessageByFieldNumber(self):
message = unittest_pb2.TestAllTypes()
- text = ('34: 1\n'
- 'repeated_uint64: 2\n')
+ text = ('34: 1\n' 'repeated_uint64: 2\n')
text_format.Parse(text, message, allow_field_number=True)
self.assertEqual(1, message.repeated_uint64[0])
self.assertEqual(2, message.repeated_uint64[1])
@@ -732,12 +709,9 @@ class Proto2Tests(TextFormatBase):
# Can't parse field number without set allow_field_number=True.
message = unittest_pb2.TestAllTypes()
text = '34:1\n'
- six.assertRaisesRegex(
- self,
- text_format.ParseError,
- (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
- r'"34".'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"34".'), text_format.Parse, text, message)
# Can't parse if field number is not found.
text = '1234:1\n'
@@ -746,7 +720,10 @@ class Proto2Tests(TextFormatBase):
text_format.ParseError,
(r'1:1 : Message type "\w+.TestAllTypes" has no field named '
r'"1234".'),
- text_format.Parse, text, message, allow_field_number=True)
+ text_format.Parse,
+ text,
+ message,
+ allow_field_number=True)
def testPrintAllExtensions(self):
message = unittest_pb2.TestAllExtensions()
@@ -824,7 +801,9 @@ class Proto2Tests(TextFormatBase):
six.assertRaisesRegex(self,
text_format.ParseError,
'Invalid field value: }',
- text_format.Parse, malformed, message,
+ text_format.Parse,
+ malformed,
+ message,
allow_unknown_extension=True)
message = unittest_mset_pb2.TestMessageSetContainer()
@@ -836,7 +815,9 @@ class Proto2Tests(TextFormatBase):
six.assertRaisesRegex(self,
text_format.ParseError,
'Invalid field value: "',
- text_format.Parse, malformed, message,
+ text_format.Parse,
+ malformed,
+ message,
allow_unknown_extension=True)
message = unittest_mset_pb2.TestMessageSetContainer()
@@ -848,7 +829,9 @@ class Proto2Tests(TextFormatBase):
six.assertRaisesRegex(self,
text_format.ParseError,
'Invalid field value: "',
- text_format.Parse, malformed, message,
+ text_format.Parse,
+ malformed,
+ message,
allow_unknown_extension=True)
message = unittest_mset_pb2.TestMessageSetContainer()
@@ -860,7 +843,9 @@ class Proto2Tests(TextFormatBase):
six.assertRaisesRegex(self,
text_format.ParseError,
'5:1 : Expected ">".',
- text_format.Parse, malformed, message,
+ text_format.Parse,
+ malformed,
+ message,
allow_unknown_extension=True)
# Don't allow unknown fields with allow_unknown_extension=True.
@@ -874,7 +859,9 @@ class Proto2Tests(TextFormatBase):
('2:3 : Message type '
'"proto2_wireformat_unittest.TestMessageSet" has no'
' field named "unknown_field".'),
- text_format.Parse, malformed, message,
+ text_format.Parse,
+ malformed,
+ message,
allow_unknown_extension=True)
# Parse known extension correcty.
@@ -896,67 +883,57 @@ class Proto2Tests(TextFormatBase):
def testParseBadExtension(self):
message = unittest_pb2.TestAllExtensions()
text = '[unknown_extension]: 8\n'
- six.assertRaisesRegex(self,
- text_format.ParseError,
- '1:2 : Extension "unknown_extension" not registered.',
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError,
+ '1:2 : Extension "unknown_extension" not registered.',
+ text_format.Parse, text, message)
message = unittest_pb2.TestAllTypes()
- six.assertRaisesRegex(self,
- text_format.ParseError,
- ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
- 'extensions.'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ '1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
+ 'extensions.'), text_format.Parse, text, message)
def testMergeDuplicateExtensionScalars(self):
message = unittest_pb2.TestAllExtensions()
text = ('[protobuf_unittest.optional_int32_extension]: 42 '
'[protobuf_unittest.optional_int32_extension]: 67')
text_format.Merge(text, message)
- self.assertEqual(
- 67,
- message.Extensions[unittest_pb2.optional_int32_extension])
+ self.assertEqual(67,
+ message.Extensions[unittest_pb2.optional_int32_extension])
def testParseDuplicateExtensionScalars(self):
message = unittest_pb2.TestAllExtensions()
text = ('[protobuf_unittest.optional_int32_extension]: 42 '
'[protobuf_unittest.optional_int32_extension]: 67')
- six.assertRaisesRegex(self,
- text_format.ParseError,
- ('1:96 : Message type "protobuf_unittest.TestAllExtensions" '
- 'should not have multiple '
- '"protobuf_unittest.optional_int32_extension" extensions.'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ '1:96 : Message type "protobuf_unittest.TestAllExtensions" '
+ 'should not have multiple '
+ '"protobuf_unittest.optional_int32_extension" extensions.'),
+ text_format.Parse, text, message)
def testParseDuplicateNestedMessageScalars(self):
message = unittest_pb2.TestAllTypes()
text = ('optional_nested_message { bb: 1 } '
'optional_nested_message { bb: 2 }')
- six.assertRaisesRegex(self,
- text_format.ParseError,
- ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" '
- 'should not have multiple "bb" fields.'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ '1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" '
+ 'should not have multiple "bb" fields.'), text_format.Parse, text,
+ message)
def testParseDuplicateScalars(self):
message = unittest_pb2.TestAllTypes()
- text = ('optional_int32: 42 '
- 'optional_int32: 67')
- six.assertRaisesRegex(self,
- text_format.ParseError,
- ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
- 'have multiple "optional_int32" fields.'),
- text_format.Parse, text, message)
+ text = ('optional_int32: 42 ' 'optional_int32: 67')
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ '1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
+ 'have multiple "optional_int32" fields.'), text_format.Parse, text,
+ message)
def testParseGroupNotClosed(self):
message = unittest_pb2.TestAllTypes()
text = 'RepeatedGroup: <'
- six.assertRaisesRegex(self,
- text_format.ParseError, '1:16 : Expected ">".',
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, '1:16 : Expected ">".',
+ text_format.Parse, text, message)
text = 'RepeatedGroup: {'
- six.assertRaisesRegex(self,
- text_format.ParseError, '1:16 : Expected "}".',
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, '1:16 : Expected "}".',
+ text_format.Parse, text, message)
def testParseEmptyGroup(self):
message = unittest_pb2.TestAllTypes()
@@ -1007,10 +984,197 @@ class Proto2Tests(TextFormatBase):
self.assertEqual(-2**34, message.map_int64_int64[-2**33])
self.assertEqual(456, message.map_uint32_uint32[123])
self.assertEqual(2**34, message.map_uint64_uint64[2**33])
- self.assertEqual("123", message.map_string_string["abc"])
+ self.assertEqual('123', message.map_string_string['abc'])
self.assertEqual(5, message.map_int32_foreign_message[111].c)
+class Proto3Tests(unittest.TestCase):
+
+ def testPrintMessageExpandAny(self):
+ packed_message = unittest_pb2.OneString()
+ packed_message.data = 'string'
+ message = any_test_pb2.TestAny()
+ message.any_value.Pack(packed_message)
+ self.assertEqual(
+ text_format.MessageToString(message,
+ descriptor_pool=descriptor_pool.Default()),
+ 'any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string"\n'
+ ' }\n'
+ '}\n')
+
+ def testPrintMessageExpandAnyRepeated(self):
+ packed_message = unittest_pb2.OneString()
+ message = any_test_pb2.TestAny()
+ packed_message.data = 'string0'
+ message.repeated_any_value.add().Pack(packed_message)
+ packed_message.data = 'string1'
+ message.repeated_any_value.add().Pack(packed_message)
+ self.assertEqual(
+ text_format.MessageToString(message,
+ descriptor_pool=descriptor_pool.Default()),
+ 'repeated_any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string0"\n'
+ ' }\n'
+ '}\n'
+ 'repeated_any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string1"\n'
+ ' }\n'
+ '}\n')
+
+ def testPrintMessageExpandAnyNoDescriptorPool(self):
+ packed_message = unittest_pb2.OneString()
+ packed_message.data = 'string'
+ message = any_test_pb2.TestAny()
+ message.any_value.Pack(packed_message)
+ self.assertEqual(
+ text_format.MessageToString(message, descriptor_pool=None),
+ 'any_value {\n'
+ ' type_url: "type.googleapis.com/protobuf_unittest.OneString"\n'
+ ' value: "\\n\\006string"\n'
+ '}\n')
+
+ def testPrintMessageExpandAnyDescriptorPoolMissingType(self):
+ packed_message = unittest_pb2.OneString()
+ packed_message.data = 'string'
+ message = any_test_pb2.TestAny()
+ message.any_value.Pack(packed_message)
+ empty_pool = descriptor_pool.DescriptorPool()
+ self.assertEqual(
+ text_format.MessageToString(message, descriptor_pool=empty_pool),
+ 'any_value {\n'
+ ' type_url: "type.googleapis.com/protobuf_unittest.OneString"\n'
+ ' value: "\\n\\006string"\n'
+ '}\n')
+
+ def testPrintMessageExpandAnyPointyBrackets(self):
+ packed_message = unittest_pb2.OneString()
+ packed_message.data = 'string'
+ message = any_test_pb2.TestAny()
+ message.any_value.Pack(packed_message)
+ self.assertEqual(
+ text_format.MessageToString(message,
+ pointy_brackets=True,
+ descriptor_pool=descriptor_pool.Default()),
+ 'any_value <\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] <\n'
+ ' data: "string"\n'
+ ' >\n'
+ '>\n')
+
+ def testPrintMessageExpandAnyAsOneLine(self):
+ packed_message = unittest_pb2.OneString()
+ packed_message.data = 'string'
+ message = any_test_pb2.TestAny()
+ message.any_value.Pack(packed_message)
+ self.assertEqual(
+ text_format.MessageToString(message,
+ as_one_line=True,
+ descriptor_pool=descriptor_pool.Default()),
+ 'any_value {'
+ ' [type.googleapis.com/protobuf_unittest.OneString]'
+ ' { data: "string" } '
+ '}')
+
+ def testPrintMessageExpandAnyAsOneLinePointyBrackets(self):
+ packed_message = unittest_pb2.OneString()
+ packed_message.data = 'string'
+ message = any_test_pb2.TestAny()
+ message.any_value.Pack(packed_message)
+ self.assertEqual(
+ text_format.MessageToString(message,
+ as_one_line=True,
+ pointy_brackets=True,
+ descriptor_pool=descriptor_pool.Default()),
+ 'any_value <'
+ ' [type.googleapis.com/protobuf_unittest.OneString]'
+ ' < data: "string" > '
+ '>')
+
+ def testMergeExpandedAny(self):
+ message = any_test_pb2.TestAny()
+ text = ('any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string"\n'
+ ' }\n'
+ '}\n')
+ text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default())
+ packed_message = unittest_pb2.OneString()
+ message.any_value.Unpack(packed_message)
+ self.assertEqual('string', packed_message.data)
+
+ def testMergeExpandedAnyRepeated(self):
+ message = any_test_pb2.TestAny()
+ text = ('repeated_any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string0"\n'
+ ' }\n'
+ '}\n'
+ 'repeated_any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string1"\n'
+ ' }\n'
+ '}\n')
+ text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default())
+ packed_message = unittest_pb2.OneString()
+ message.repeated_any_value[0].Unpack(packed_message)
+ self.assertEqual('string0', packed_message.data)
+ message.repeated_any_value[1].Unpack(packed_message)
+ self.assertEqual('string1', packed_message.data)
+
+ def testMergeExpandedAnyPointyBrackets(self):
+ message = any_test_pb2.TestAny()
+ text = ('any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] <\n'
+ ' data: "string"\n'
+ ' >\n'
+ '}\n')
+ text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default())
+ packed_message = unittest_pb2.OneString()
+ message.any_value.Unpack(packed_message)
+ self.assertEqual('string', packed_message.data)
+
+ def testMergeExpandedAnyNoDescriptorPool(self):
+ message = any_test_pb2.TestAny()
+ text = ('any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string"\n'
+ ' }\n'
+ '}\n')
+ with self.assertRaises(text_format.ParseError) as e:
+ text_format.Merge(text, message, descriptor_pool=None)
+ self.assertEqual(str(e.exception),
+ 'Descriptor pool required to parse expanded Any field')
+
+ def testMergeExpandedAnyDescriptorPoolMissingType(self):
+ message = any_test_pb2.TestAny()
+ text = ('any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string"\n'
+ ' }\n'
+ '}\n')
+ with self.assertRaises(text_format.ParseError) as e:
+ empty_pool = descriptor_pool.DescriptorPool()
+ text_format.Merge(text, message, descriptor_pool=empty_pool)
+ self.assertEqual(
+ str(e.exception),
+ 'Type protobuf_unittest.OneString not found in descriptor pool')
+
+ def testMergeUnexpandedAny(self):
+ text = ('any_value {\n'
+ ' type_url: "type.googleapis.com/protobuf_unittest.OneString"\n'
+ ' value: "\\n\\006string"\n'
+ '}\n')
+ message = any_test_pb2.TestAny()
+ text_format.Merge(text, message)
+ packed_message = unittest_pb2.OneString()
+ message.any_value.Unpack(packed_message)
+ self.assertEqual('string', packed_message.data)
+
+
class TokenizerTest(unittest.TestCase):
def testSimpleTokenCases(self):
@@ -1021,79 +1185,55 @@ class TokenizerTest(unittest.TestCase):
'ID9: 22 ID10: -111111111111111111 ID11: -22\n'
'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f '
'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ')
- tokenizer = text_format._Tokenizer(text.splitlines())
- methods = [(tokenizer.ConsumeIdentifier, 'identifier1'),
- ':',
+ tokenizer = text_format.Tokenizer(text.splitlines())
+ methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), ':',
(tokenizer.ConsumeString, 'string1'),
- (tokenizer.ConsumeIdentifier, 'identifier2'),
- ':',
- (tokenizer.ConsumeInt32, 123),
- (tokenizer.ConsumeIdentifier, 'identifier3'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'identifier2'), ':',
+ (tokenizer.ConsumeInteger, 123),
+ (tokenizer.ConsumeIdentifier, 'identifier3'), ':',
(tokenizer.ConsumeString, 'string'),
- (tokenizer.ConsumeIdentifier, 'identifiER_4'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'identifiER_4'), ':',
(tokenizer.ConsumeFloat, 1.1e+2),
- (tokenizer.ConsumeIdentifier, 'ID5'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'ID5'), ':',
(tokenizer.ConsumeFloat, -0.23),
- (tokenizer.ConsumeIdentifier, 'ID6'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'ID6'), ':',
(tokenizer.ConsumeString, 'aaaa\'bbbb'),
- (tokenizer.ConsumeIdentifier, 'ID7'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'ID7'), ':',
(tokenizer.ConsumeString, 'aa\"bb'),
- (tokenizer.ConsumeIdentifier, 'ID8'),
- ':',
- '{',
- (tokenizer.ConsumeIdentifier, 'A'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'ID8'), ':', '{',
+ (tokenizer.ConsumeIdentifier, 'A'), ':',
(tokenizer.ConsumeFloat, float('inf')),
- (tokenizer.ConsumeIdentifier, 'B'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'B'), ':',
(tokenizer.ConsumeFloat, -float('inf')),
- (tokenizer.ConsumeIdentifier, 'C'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'C'), ':',
(tokenizer.ConsumeBool, True),
- (tokenizer.ConsumeIdentifier, 'D'),
- ':',
- (tokenizer.ConsumeBool, False),
- '}',
- (tokenizer.ConsumeIdentifier, 'ID9'),
- ':',
- (tokenizer.ConsumeUint32, 22),
- (tokenizer.ConsumeIdentifier, 'ID10'),
- ':',
- (tokenizer.ConsumeInt64, -111111111111111111),
- (tokenizer.ConsumeIdentifier, 'ID11'),
- ':',
- (tokenizer.ConsumeInt32, -22),
- (tokenizer.ConsumeIdentifier, 'ID12'),
- ':',
- (tokenizer.ConsumeUint64, 2222222222222222222),
- (tokenizer.ConsumeIdentifier, 'ID13'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'D'), ':',
+ (tokenizer.ConsumeBool, False), '}',
+ (tokenizer.ConsumeIdentifier, 'ID9'), ':',
+ (tokenizer.ConsumeInteger, 22),
+ (tokenizer.ConsumeIdentifier, 'ID10'), ':',
+ (tokenizer.ConsumeInteger, -111111111111111111),
+ (tokenizer.ConsumeIdentifier, 'ID11'), ':',
+ (tokenizer.ConsumeInteger, -22),
+ (tokenizer.ConsumeIdentifier, 'ID12'), ':',
+ (tokenizer.ConsumeInteger, 2222222222222222222),
+ (tokenizer.ConsumeIdentifier, 'ID13'), ':',
(tokenizer.ConsumeFloat, 1.23456),
- (tokenizer.ConsumeIdentifier, 'ID14'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'ID14'), ':',
(tokenizer.ConsumeFloat, 1.2e+2),
- (tokenizer.ConsumeIdentifier, 'false_bool'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'false_bool'), ':',
(tokenizer.ConsumeBool, False),
- (tokenizer.ConsumeIdentifier, 'true_BOOL'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'true_BOOL'), ':',
(tokenizer.ConsumeBool, True),
- (tokenizer.ConsumeIdentifier, 'true_bool1'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'true_bool1'), ':',
(tokenizer.ConsumeBool, True),
- (tokenizer.ConsumeIdentifier, 'false_BOOL1'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'false_BOOL1'), ':',
(tokenizer.ConsumeBool, False)]
i = 0
while not tokenizer.AtEnd():
m = methods[i]
- if type(m) == str:
+ if isinstance(m, str):
token = tokenizer.token
self.assertEqual(token, m)
tokenizer.NextToken()
@@ -1101,59 +1241,119 @@ class TokenizerTest(unittest.TestCase):
self.assertEqual(m[1], m[0]())
i += 1
- def testConsumeIntegers(self):
+ def testConsumeAbstractIntegers(self):
# This test only tests the failures in the integer parsing methods as well
# as the '0' special cases.
int64_max = (1 << 63) - 1
uint32_max = (1 << 32) - 1
text = '-1 %d %d' % (uint32_max + 1, int64_max + 1)
- tokenizer = text_format._Tokenizer(text.splitlines())
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32)
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint64)
- self.assertEqual(-1, tokenizer.ConsumeInt32())
+ tokenizer = text_format.Tokenizer(text.splitlines())
+ self.assertEqual(-1, tokenizer.ConsumeInteger())
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32)
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt32)
- self.assertEqual(uint32_max + 1, tokenizer.ConsumeInt64())
+ self.assertEqual(uint32_max + 1, tokenizer.ConsumeInteger())
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt64)
- self.assertEqual(int64_max + 1, tokenizer.ConsumeUint64())
+ self.assertEqual(int64_max + 1, tokenizer.ConsumeInteger())
+ self.assertTrue(tokenizer.AtEnd())
+
+ text = '-0 0'
+ tokenizer = text_format.Tokenizer(text.splitlines())
+ self.assertEqual(0, tokenizer.ConsumeInteger())
+ self.assertEqual(0, tokenizer.ConsumeInteger())
+ self.assertTrue(tokenizer.AtEnd())
+
+ def testConsumeIntegers(self):
+ # This test only tests the failures in the integer parsing methods as well
+ # as the '0' special cases.
+ int64_max = (1 << 63) - 1
+ uint32_max = (1 << 32) - 1
+ text = '-1 %d %d' % (uint32_max + 1, int64_max + 1)
+ tokenizer = text_format.Tokenizer(text.splitlines())
+ self.assertRaises(text_format.ParseError,
+ text_format._ConsumeUint32, tokenizer)
+ self.assertRaises(text_format.ParseError,
+ text_format._ConsumeUint64, tokenizer)
+ self.assertEqual(-1, text_format._ConsumeInt32(tokenizer))
+
+ self.assertRaises(text_format.ParseError,
+ text_format._ConsumeUint32, tokenizer)
+ self.assertRaises(text_format.ParseError,
+ text_format._ConsumeInt32, tokenizer)
+ self.assertEqual(uint32_max + 1, text_format._ConsumeInt64(tokenizer))
+
+ self.assertRaises(text_format.ParseError,
+ text_format._ConsumeInt64, tokenizer)
+ self.assertEqual(int64_max + 1, text_format._ConsumeUint64(tokenizer))
self.assertTrue(tokenizer.AtEnd())
text = '-0 -0 0 0'
- tokenizer = text_format._Tokenizer(text.splitlines())
- self.assertEqual(0, tokenizer.ConsumeUint32())
- self.assertEqual(0, tokenizer.ConsumeUint64())
- self.assertEqual(0, tokenizer.ConsumeUint32())
- self.assertEqual(0, tokenizer.ConsumeUint64())
+ tokenizer = text_format.Tokenizer(text.splitlines())
+ self.assertEqual(0, text_format._ConsumeUint32(tokenizer))
+ self.assertEqual(0, text_format._ConsumeUint64(tokenizer))
+ self.assertEqual(0, text_format._ConsumeUint32(tokenizer))
+ self.assertEqual(0, text_format._ConsumeUint64(tokenizer))
self.assertTrue(tokenizer.AtEnd())
def testConsumeByteString(self):
text = '"string1\''
- tokenizer = text_format._Tokenizer(text.splitlines())
+ tokenizer = text_format.Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
text = 'string1"'
- tokenizer = text_format._Tokenizer(text.splitlines())
+ tokenizer = text_format.Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
text = '\n"\\xt"'
- tokenizer = text_format._Tokenizer(text.splitlines())
+ tokenizer = text_format.Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
text = '\n"\\"'
- tokenizer = text_format._Tokenizer(text.splitlines())
+ tokenizer = text_format.Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
text = '\n"\\x"'
- tokenizer = text_format._Tokenizer(text.splitlines())
+ tokenizer = text_format.Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
def testConsumeBool(self):
text = 'not-a-bool'
- tokenizer = text_format._Tokenizer(text.splitlines())
+ tokenizer = text_format.Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool)
+ def testSkipComment(self):
+ tokenizer = text_format.Tokenizer('# some comment'.splitlines())
+ self.assertTrue(tokenizer.AtEnd())
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeComment)
+
+ def testConsumeComment(self):
+ tokenizer = text_format.Tokenizer('# some comment'.splitlines(),
+ skip_comments=False)
+ self.assertFalse(tokenizer.AtEnd())
+ self.assertEqual('# some comment', tokenizer.ConsumeComment())
+ self.assertTrue(tokenizer.AtEnd())
+
+ def testConsumeTwoComments(self):
+ text = '# some comment\n# another comment'
+ tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False)
+ self.assertEqual('# some comment', tokenizer.ConsumeComment())
+ self.assertFalse(tokenizer.AtEnd())
+ self.assertEqual('# another comment', tokenizer.ConsumeComment())
+ self.assertTrue(tokenizer.AtEnd())
+
+ def testConsumeTrailingComment(self):
+ text = 'some_number: 4\n# some comment'
+ tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False)
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeComment)
+
+ self.assertEqual('some_number', tokenizer.ConsumeIdentifier())
+ self.assertEqual(tokenizer.token, ':')
+ tokenizer.NextToken()
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeComment)
+ self.assertEqual(4, tokenizer.ConsumeInteger())
+ self.assertFalse(tokenizer.AtEnd())
+
+ self.assertEqual('# some comment', tokenizer.ConsumeComment())
+ self.assertTrue(tokenizer.AtEnd())
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py
index be6a9b63..bb6a1998 100644
--- a/python/google/protobuf/json_format.py
+++ b/python/google/protobuf/json_format.py
@@ -53,6 +53,7 @@ import re
import six
import sys
+from operator import methodcaller
from google.protobuf import descriptor
from google.protobuf import symbol_database
@@ -98,22 +99,8 @@ def MessageToJson(message, including_default_value_fields=False):
Returns:
A string containing the JSON formatted protocol buffer message.
"""
- js = _MessageToJsonObject(message, including_default_value_fields)
- return json.dumps(js, indent=2)
-
-
-def _MessageToJsonObject(message, including_default_value_fields):
- """Converts message to an object according to Proto3 JSON Specification."""
- message_descriptor = message.DESCRIPTOR
- full_name = message_descriptor.full_name
- if _IsWrapperMessage(message_descriptor):
- return _WrapperMessageToJsonObject(message)
- if full_name in _WKTJSONMETHODS:
- return _WKTJSONMETHODS[full_name][0](
- message, including_default_value_fields)
- js = {}
- return _RegularMessageToJsonObject(
- message, js, including_default_value_fields)
+ printer = _Printer(including_default_value_fields)
+ return printer.ToJsonString(message)
def _IsMapEntry(field):
@@ -122,179 +109,179 @@ def _IsMapEntry(field):
field.message_type.GetOptions().map_entry)
-def _RegularMessageToJsonObject(message, js, including_default_value_fields):
- """Converts normal message according to Proto3 JSON Specification."""
- fields = message.ListFields()
- include_default = including_default_value_fields
+class _Printer(object):
+ """JSON format printer for protocol message."""
- try:
- for field, value in fields:
- name = field.camelcase_name
- if _IsMapEntry(field):
- # Convert a map field.
- v_field = field.message_type.fields_by_name['value']
- js_map = {}
- for key in value:
- if isinstance(key, bool):
- if key:
- recorded_key = 'true'
- else:
- recorded_key = 'false'
- else:
- recorded_key = key
- js_map[recorded_key] = _FieldToJsonObject(
- v_field, value[key], including_default_value_fields)
- js[name] = js_map
- elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- # Convert a repeated field.
- js[name] = [_FieldToJsonObject(field, k, include_default)
- for k in value]
- else:
- js[name] = _FieldToJsonObject(field, value, include_default)
-
- # Serialize default value if including_default_value_fields is True.
- if including_default_value_fields:
- message_descriptor = message.DESCRIPTOR
- for field in message_descriptor.fields:
- # Singular message fields and oneof fields will not be affected.
- if ((field.label != descriptor.FieldDescriptor.LABEL_REPEATED and
- field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE) or
- field.containing_oneof):
- continue
- name = field.camelcase_name
- if name in js:
- # Skip the field which has been serailized already.
- continue
- if _IsMapEntry(field):
- js[name] = {}
- elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- js[name] = []
- else:
- js[name] = _FieldToJsonObject(field, field.default_value)
+ def __init__(self,
+ including_default_value_fields=False):
+ self.including_default_value_fields = including_default_value_fields
- except ValueError as e:
- raise SerializeToJsonError(
- 'Failed to serialize {0} field: {1}.'.format(field.name, e))
+ def ToJsonString(self, message):
+ js = self._MessageToJsonObject(message)
+ return json.dumps(js, indent=2)
- return js
+ def _MessageToJsonObject(self, message):
+ """Converts message to an object according to Proto3 JSON Specification."""
+ message_descriptor = message.DESCRIPTOR
+ full_name = message_descriptor.full_name
+ if _IsWrapperMessage(message_descriptor):
+ return self._WrapperMessageToJsonObject(message)
+ if full_name in _WKTJSONMETHODS:
+ return methodcaller(_WKTJSONMETHODS[full_name][0], message)(self)
+ js = {}
+ return self._RegularMessageToJsonObject(message, js)
+ def _RegularMessageToJsonObject(self, message, js):
+ """Converts normal message according to Proto3 JSON Specification."""
+ fields = message.ListFields()
-def _FieldToJsonObject(
- field, value, including_default_value_fields=False):
- """Converts field value according to Proto3 JSON Specification."""
- if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- return _MessageToJsonObject(value, including_default_value_fields)
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
- enum_value = field.enum_type.values_by_number.get(value, None)
- if enum_value is not None:
- return enum_value.name
- else:
- raise SerializeToJsonError('Enum field contains an integer value '
- 'which can not mapped to an enum value.')
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
- if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
- # Use base64 Data encoding for bytes
- return base64.b64encode(value).decode('utf-8')
- else:
- return value
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
- return bool(value)
- elif field.cpp_type in _INT64_TYPES:
- return str(value)
- elif field.cpp_type in _FLOAT_TYPES:
- if math.isinf(value):
- if value < 0.0:
- return _NEG_INFINITY
- else:
- return _INFINITY
- if math.isnan(value):
- return _NAN
- return value
+ try:
+ for field, value in fields:
+ name = field.camelcase_name
+ if _IsMapEntry(field):
+ # Convert a map field.
+ v_field = field.message_type.fields_by_name['value']
+ js_map = {}
+ for key in value:
+ if isinstance(key, bool):
+ if key:
+ recorded_key = 'true'
+ else:
+ recorded_key = 'false'
+ else:
+ recorded_key = key
+ js_map[recorded_key] = self._FieldToJsonObject(
+ v_field, value[key])
+ js[name] = js_map
+ elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ # Convert a repeated field.
+ js[name] = [self._FieldToJsonObject(field, k)
+ for k in value]
+ else:
+ js[name] = self._FieldToJsonObject(field, value)
+
+ # Serialize default value if including_default_value_fields is True.
+ if self.including_default_value_fields:
+ message_descriptor = message.DESCRIPTOR
+ for field in message_descriptor.fields:
+ # Singular message fields and oneof fields will not be affected.
+ if ((field.label != descriptor.FieldDescriptor.LABEL_REPEATED and
+ field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE) or
+ field.containing_oneof):
+ continue
+ name = field.camelcase_name
+ if name in js:
+ # Skip the field which has been serailized already.
+ continue
+ if _IsMapEntry(field):
+ js[name] = {}
+ elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ js[name] = []
+ else:
+ js[name] = self._FieldToJsonObject(field, field.default_value)
+ except ValueError as e:
+ raise SerializeToJsonError(
+ 'Failed to serialize {0} field: {1}.'.format(field.name, e))
-def _AnyMessageToJsonObject(message, including_default):
- """Converts Any message according to Proto3 JSON Specification."""
- if not message.ListFields():
- return {}
- # Must print @type first, use OrderedDict instead of {}
- js = OrderedDict()
- type_url = message.type_url
- js['@type'] = type_url
- sub_message = _CreateMessageFromTypeUrl(type_url)
- sub_message.ParseFromString(message.value)
- message_descriptor = sub_message.DESCRIPTOR
- full_name = message_descriptor.full_name
- if _IsWrapperMessage(message_descriptor):
- js['value'] = _WrapperMessageToJsonObject(sub_message)
return js
- if full_name in _WKTJSONMETHODS:
- js['value'] = _WKTJSONMETHODS[full_name][0](sub_message, including_default)
- return js
- return _RegularMessageToJsonObject(sub_message, js, including_default)
-
-
-def _CreateMessageFromTypeUrl(type_url):
- # TODO(jieluo): Should add a way that users can register the type resolver
- # instead of the default one.
- db = symbol_database.Default()
- type_name = type_url.split('/')[-1]
- try:
- message_descriptor = db.pool.FindMessageTypeByName(type_name)
- except KeyError:
- raise TypeError(
- 'Can not find message descriptor by type_url: {0}.'.format(type_url))
- message_class = db.GetPrototype(message_descriptor)
- return message_class()
-
-
-def _GenericMessageToJsonObject(message, unused_including_default):
- """Converts message by ToJsonString according to Proto3 JSON Specification."""
- # Duration, Timestamp and FieldMask have ToJsonString method to do the
- # convert. Users can also call the method directly.
- return message.ToJsonString()
-
-
-def _ValueMessageToJsonObject(message, unused_including_default=False):
- """Converts Value message according to Proto3 JSON Specification."""
- which = message.WhichOneof('kind')
- # If the Value message is not set treat as null_value when serialize
- # to JSON. The parse back result will be different from original message.
- if which is None or which == 'null_value':
- return None
- if which == 'list_value':
- return _ListValueMessageToJsonObject(message.list_value)
- if which == 'struct_value':
- value = message.struct_value
- else:
- value = getattr(message, which)
- oneof_descriptor = message.DESCRIPTOR.fields_by_name[which]
- return _FieldToJsonObject(oneof_descriptor, value)
+ def _FieldToJsonObject(self, field, value):
+ """Converts field value according to Proto3 JSON Specification."""
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ return self._MessageToJsonObject(value)
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
+ enum_value = field.enum_type.values_by_number.get(value, None)
+ if enum_value is not None:
+ return enum_value.name
+ else:
+ raise SerializeToJsonError('Enum field contains an integer value '
+ 'which can not mapped to an enum value.')
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
+ if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
+ # Use base64 Data encoding for bytes
+ return base64.b64encode(value).decode('utf-8')
+ else:
+ return value
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
+ return bool(value)
+ elif field.cpp_type in _INT64_TYPES:
+ return str(value)
+ elif field.cpp_type in _FLOAT_TYPES:
+ if math.isinf(value):
+ if value < 0.0:
+ return _NEG_INFINITY
+ else:
+ return _INFINITY
+ if math.isnan(value):
+ return _NAN
+ return value
+
+ def _AnyMessageToJsonObject(self, message):
+ """Converts Any message according to Proto3 JSON Specification."""
+ if not message.ListFields():
+ return {}
+ # Must print @type first, use OrderedDict instead of {}
+ js = OrderedDict()
+ type_url = message.type_url
+ js['@type'] = type_url
+ sub_message = _CreateMessageFromTypeUrl(type_url)
+ sub_message.ParseFromString(message.value)
+ message_descriptor = sub_message.DESCRIPTOR
+ full_name = message_descriptor.full_name
+ if _IsWrapperMessage(message_descriptor):
+ js['value'] = self._WrapperMessageToJsonObject(sub_message)
+ return js
+ if full_name in _WKTJSONMETHODS:
+ js['value'] = methodcaller(_WKTJSONMETHODS[full_name][0],
+ sub_message)(self)
+ return js
+ return self._RegularMessageToJsonObject(sub_message, js)
+
+ def _GenericMessageToJsonObject(self, message):
+ """Converts message according to Proto3 JSON Specification."""
+ # Duration, Timestamp and FieldMask have ToJsonString method to do the
+ # convert. Users can also call the method directly.
+ return message.ToJsonString()
+
+ def _ValueMessageToJsonObject(self, message):
+ """Converts Value message according to Proto3 JSON Specification."""
+ which = message.WhichOneof('kind')
+ # If the Value message is not set treat as null_value when serialize
+ # to JSON. The parse back result will be different from original message.
+ if which is None or which == 'null_value':
+ return None
+ if which == 'list_value':
+ return self._ListValueMessageToJsonObject(message.list_value)
+ if which == 'struct_value':
+ value = message.struct_value
+ else:
+ value = getattr(message, which)
+ oneof_descriptor = message.DESCRIPTOR.fields_by_name[which]
+ return self._FieldToJsonObject(oneof_descriptor, value)
-def _ListValueMessageToJsonObject(message, unused_including_default=False):
- """Converts ListValue message according to Proto3 JSON Specification."""
- return [_ValueMessageToJsonObject(value)
- for value in message.values]
+ def _ListValueMessageToJsonObject(self, message):
+ """Converts ListValue message according to Proto3 JSON Specification."""
+ return [self._ValueMessageToJsonObject(value)
+ for value in message.values]
+ def _StructMessageToJsonObject(self, message):
+ """Converts Struct message according to Proto3 JSON Specification."""
+ fields = message.fields
+ ret = {}
+ for key in fields:
+ ret[key] = self._ValueMessageToJsonObject(fields[key])
+ return ret
-def _StructMessageToJsonObject(message, unused_including_default=False):
- """Converts Struct message according to Proto3 JSON Specification."""
- fields = message.fields
- ret = {}
- for key in fields:
- ret[key] = _ValueMessageToJsonObject(fields[key])
- return ret
+ def _WrapperMessageToJsonObject(self, message):
+ return self._FieldToJsonObject(
+ message.DESCRIPTOR.fields_by_name['value'], message.value)
def _IsWrapperMessage(message_descriptor):
return message_descriptor.file.name == 'google/protobuf/wrappers.proto'
-def _WrapperMessageToJsonObject(message):
- return _FieldToJsonObject(
- message.DESCRIPTOR.fields_by_name['value'], message.value)
-
-
def _DuplicateChecker(js):
result = {}
for name, value in js:
@@ -304,12 +291,27 @@ def _DuplicateChecker(js):
return result
-def Parse(text, message):
+def _CreateMessageFromTypeUrl(type_url):
+ # TODO(jieluo): Should add a way that users can register the type resolver
+ # instead of the default one.
+ db = symbol_database.Default()
+ type_name = type_url.split('/')[-1]
+ try:
+ message_descriptor = db.pool.FindMessageTypeByName(type_name)
+ except KeyError:
+ raise TypeError(
+ 'Can not find message descriptor by type_url: {0}.'.format(type_url))
+ message_class = db.GetPrototype(message_descriptor)
+ return message_class()
+
+
+def Parse(text, message, ignore_unknown_fields=False):
"""Parses a JSON representation of a protocol message into a message.
Args:
text: Message JSON representation.
message: A protocol beffer message to merge into.
+ ignore_unknown_fields: If True, do not raise errors for unknown fields.
Returns:
The same message passed as argument.
@@ -326,213 +328,217 @@ def Parse(text, message):
js = json.loads(text, object_pairs_hook=_DuplicateChecker)
except ValueError as e:
raise ParseError('Failed to load JSON: {0}.'.format(str(e)))
- _ConvertMessage(js, message)
+ parser = _Parser(ignore_unknown_fields)
+ parser.ConvertMessage(js, message)
return message
-def _ConvertFieldValuePair(js, message):
- """Convert field value pairs into regular message.
+_INT_OR_FLOAT = six.integer_types + (float,)
- Args:
- js: A JSON object to convert the field value pairs.
- message: A regular protocol message to record the data.
- Raises:
- ParseError: In case of problems converting.
- """
- names = []
- message_descriptor = message.DESCRIPTOR
- for name in js:
- try:
- field = message_descriptor.fields_by_camelcase_name.get(name, None)
- if not field:
- raise ParseError(
- 'Message type "{0}" has no field named "{1}".'.format(
- message_descriptor.full_name, name))
- if name in names:
- raise ParseError(
- 'Message type "{0}" should not have multiple "{1}" fields.'.format(
- message.DESCRIPTOR.full_name, name))
- names.append(name)
- # Check no other oneof field is parsed.
- if field.containing_oneof is not None:
- oneof_name = field.containing_oneof.name
- if oneof_name in names:
- raise ParseError('Message type "{0}" should not have multiple "{1}" '
- 'oneof fields.'.format(
- message.DESCRIPTOR.full_name, oneof_name))
- names.append(oneof_name)
-
- value = js[name]
- if value is None:
- message.ClearField(field.name)
- continue
-
- # Parse field value.
- if _IsMapEntry(field):
- message.ClearField(field.name)
- _ConvertMapFieldValue(value, message, field)
- elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- message.ClearField(field.name)
- if not isinstance(value, list):
- raise ParseError('repeated field {0} must be in [] which is '
- '{1}.'.format(name, value))
- if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- # Repeated message field.
- for item in value:
- sub_message = getattr(message, field.name).add()
- # None is a null_value in Value.
- if (item is None and
- sub_message.DESCRIPTOR.full_name != 'google.protobuf.Value'):
- raise ParseError('null is not allowed to be used as an element'
- ' in a repeated field.')
- _ConvertMessage(item, sub_message)
- else:
- # Repeated scalar field.
- for item in value:
- if item is None:
- raise ParseError('null is not allowed to be used as an element'
- ' in a repeated field.')
- getattr(message, field.name).append(
- _ConvertScalarFieldValue(item, field))
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- sub_message = getattr(message, field.name)
- _ConvertMessage(value, sub_message)
- else:
- setattr(message, field.name, _ConvertScalarFieldValue(value, field))
- except ParseError as e:
- if field and field.containing_oneof is None:
- raise ParseError('Failed to parse {0} field: {1}'.format(name, e))
- else:
- raise ParseError(str(e))
- except ValueError as e:
- raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
- except TypeError as e:
- raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
+class _Parser(object):
+ """JSON format parser for protocol message."""
+ def __init__(self,
+ ignore_unknown_fields):
+ self.ignore_unknown_fields = ignore_unknown_fields
-def _ConvertMessage(value, message):
- """Convert a JSON object into a message.
+ def ConvertMessage(self, value, message):
+ """Convert a JSON object into a message.
- Args:
- value: A JSON object.
- message: A WKT or regular protocol message to record the data.
+ Args:
+ value: A JSON object.
+ message: A WKT or regular protocol message to record the data.
- Raises:
- ParseError: In case of convert problems.
- """
- message_descriptor = message.DESCRIPTOR
- full_name = message_descriptor.full_name
- if _IsWrapperMessage(message_descriptor):
- _ConvertWrapperMessage(value, message)
- elif full_name in _WKTJSONMETHODS:
- _WKTJSONMETHODS[full_name][1](value, message)
- else:
- _ConvertFieldValuePair(value, message)
-
-
-def _ConvertAnyMessage(value, message):
- """Convert a JSON representation into Any message."""
- if isinstance(value, dict) and not value:
- return
- try:
- type_url = value['@type']
- except KeyError:
- raise ParseError('@type is missing when parsing any message.')
-
- sub_message = _CreateMessageFromTypeUrl(type_url)
- message_descriptor = sub_message.DESCRIPTOR
- full_name = message_descriptor.full_name
- if _IsWrapperMessage(message_descriptor):
- _ConvertWrapperMessage(value['value'], sub_message)
- elif full_name in _WKTJSONMETHODS:
- _WKTJSONMETHODS[full_name][1](value['value'], sub_message)
- else:
- del value['@type']
- _ConvertFieldValuePair(value, sub_message)
- # Sets Any message
- message.value = sub_message.SerializeToString()
- message.type_url = type_url
-
-
-def _ConvertGenericMessage(value, message):
- """Convert a JSON representation into message with FromJsonString."""
- # Durantion, Timestamp, FieldMask have FromJsonString method to do the
- # convert. Users can also call the method directly.
- message.FromJsonString(value)
+ Raises:
+ ParseError: In case of convert problems.
+ """
+ message_descriptor = message.DESCRIPTOR
+ full_name = message_descriptor.full_name
+ if _IsWrapperMessage(message_descriptor):
+ self._ConvertWrapperMessage(value, message)
+ elif full_name in _WKTJSONMETHODS:
+ methodcaller(_WKTJSONMETHODS[full_name][1], value, message)(self)
+ else:
+ self._ConvertFieldValuePair(value, message)
+
+ def _ConvertFieldValuePair(self, js, message):
+ """Convert field value pairs into regular message.
+
+ Args:
+ js: A JSON object to convert the field value pairs.
+ message: A regular protocol message to record the data.
+
+ Raises:
+ ParseError: In case of problems converting.
+ """
+ names = []
+ message_descriptor = message.DESCRIPTOR
+ for name in js:
+ try:
+ field = message_descriptor.fields_by_camelcase_name.get(name, None)
+ if not field:
+ if self.ignore_unknown_fields:
+ continue
+ raise ParseError(
+ 'Message type "{0}" has no field named "{1}".'.format(
+ message_descriptor.full_name, name))
+ if name in names:
+ raise ParseError('Message type "{0}" should not have multiple '
+ '"{1}" fields.'.format(
+ message.DESCRIPTOR.full_name, name))
+ names.append(name)
+ # Check no other oneof field is parsed.
+ if field.containing_oneof is not None:
+ oneof_name = field.containing_oneof.name
+ if oneof_name in names:
+ raise ParseError('Message type "{0}" should not have multiple '
+ '"{1}" oneof fields.'.format(
+ message.DESCRIPTOR.full_name, oneof_name))
+ names.append(oneof_name)
+
+ value = js[name]
+ if value is None:
+ message.ClearField(field.name)
+ continue
+ # Parse field value.
+ if _IsMapEntry(field):
+ message.ClearField(field.name)
+ self._ConvertMapFieldValue(value, message, field)
+ elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ message.ClearField(field.name)
+ if not isinstance(value, list):
+ raise ParseError('repeated field {0} must be in [] which is '
+ '{1}.'.format(name, value))
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ # Repeated message field.
+ for item in value:
+ sub_message = getattr(message, field.name).add()
+ # None is a null_value in Value.
+ if (item is None and
+ sub_message.DESCRIPTOR.full_name != 'google.protobuf.Value'):
+ raise ParseError('null is not allowed to be used as an element'
+ ' in a repeated field.')
+ self.ConvertMessage(item, sub_message)
+ else:
+ # Repeated scalar field.
+ for item in value:
+ if item is None:
+ raise ParseError('null is not allowed to be used as an element'
+ ' in a repeated field.')
+ getattr(message, field.name).append(
+ _ConvertScalarFieldValue(item, field))
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ sub_message = getattr(message, field.name)
+ self.ConvertMessage(value, sub_message)
+ else:
+ setattr(message, field.name, _ConvertScalarFieldValue(value, field))
+ except ParseError as e:
+ if field and field.containing_oneof is None:
+ raise ParseError('Failed to parse {0} field: {1}'.format(name, e))
+ else:
+ raise ParseError(str(e))
+ except ValueError as e:
+ raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
+ except TypeError as e:
+ raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
+
+ def _ConvertAnyMessage(self, value, message):
+ """Convert a JSON representation into Any message."""
+ if isinstance(value, dict) and not value:
+ return
+ try:
+ type_url = value['@type']
+ except KeyError:
+ raise ParseError('@type is missing when parsing any message.')
+
+ sub_message = _CreateMessageFromTypeUrl(type_url)
+ message_descriptor = sub_message.DESCRIPTOR
+ full_name = message_descriptor.full_name
+ if _IsWrapperMessage(message_descriptor):
+ self._ConvertWrapperMessage(value['value'], sub_message)
+ elif full_name in _WKTJSONMETHODS:
+ methodcaller(
+ _WKTJSONMETHODS[full_name][1], value['value'], sub_message)(self)
+ else:
+ del value['@type']
+ self._ConvertFieldValuePair(value, sub_message)
+ # Sets Any message
+ message.value = sub_message.SerializeToString()
+ message.type_url = type_url
+
+ def _ConvertGenericMessage(self, value, message):
+ """Convert a JSON representation into message with FromJsonString."""
+ # Durantion, Timestamp, FieldMask have FromJsonString method to do the
+ # convert. Users can also call the method directly.
+ message.FromJsonString(value)
+
+ def _ConvertValueMessage(self, value, message):
+ """Convert a JSON representation into Value message."""
+ if isinstance(value, dict):
+ self._ConvertStructMessage(value, message.struct_value)
+ elif isinstance(value, list):
+ self. _ConvertListValueMessage(value, message.list_value)
+ elif value is None:
+ message.null_value = 0
+ elif isinstance(value, bool):
+ message.bool_value = value
+ elif isinstance(value, six.string_types):
+ message.string_value = value
+ elif isinstance(value, _INT_OR_FLOAT):
+ message.number_value = value
+ else:
+ raise ParseError('Unexpected type for Value message.')
-_INT_OR_FLOAT = six.integer_types + (float,)
+ def _ConvertListValueMessage(self, value, message):
+ """Convert a JSON representation into ListValue message."""
+ if not isinstance(value, list):
+ raise ParseError(
+ 'ListValue must be in [] which is {0}.'.format(value))
+ message.ClearField('values')
+ for item in value:
+ self._ConvertValueMessage(item, message.values.add())
+
+ def _ConvertStructMessage(self, value, message):
+ """Convert a JSON representation into Struct message."""
+ if not isinstance(value, dict):
+ raise ParseError(
+ 'Struct must be in a dict which is {0}.'.format(value))
+ for key in value:
+ self._ConvertValueMessage(value[key], message.fields[key])
+ return
+ def _ConvertWrapperMessage(self, value, message):
+ """Convert a JSON representation into Wrapper message."""
+ field = message.DESCRIPTOR.fields_by_name['value']
+ setattr(message, 'value', _ConvertScalarFieldValue(value, field))
-def _ConvertValueMessage(value, message):
- """Convert a JSON representation into Value message."""
- if isinstance(value, dict):
- _ConvertStructMessage(value, message.struct_value)
- elif isinstance(value, list):
- _ConvertListValueMessage(value, message.list_value)
- elif value is None:
- message.null_value = 0
- elif isinstance(value, bool):
- message.bool_value = value
- elif isinstance(value, six.string_types):
- message.string_value = value
- elif isinstance(value, _INT_OR_FLOAT):
- message.number_value = value
- else:
- raise ParseError('Unexpected type for Value message.')
-
-
-def _ConvertListValueMessage(value, message):
- """Convert a JSON representation into ListValue message."""
- if not isinstance(value, list):
- raise ParseError(
- 'ListValue must be in [] which is {0}.'.format(value))
- message.ClearField('values')
- for item in value:
- _ConvertValueMessage(item, message.values.add())
-
-
-def _ConvertStructMessage(value, message):
- """Convert a JSON representation into Struct message."""
- if not isinstance(value, dict):
- raise ParseError(
- 'Struct must be in a dict which is {0}.'.format(value))
- for key in value:
- _ConvertValueMessage(value[key], message.fields[key])
- return
-
-
-def _ConvertWrapperMessage(value, message):
- """Convert a JSON representation into Wrapper message."""
- field = message.DESCRIPTOR.fields_by_name['value']
- setattr(message, 'value', _ConvertScalarFieldValue(value, field))
-
-
-def _ConvertMapFieldValue(value, message, field):
- """Convert map field value for a message map field.
+ def _ConvertMapFieldValue(self, value, message, field):
+ """Convert map field value for a message map field.
- Args:
- value: A JSON object to convert the map field value.
- message: A protocol message to record the converted data.
- field: The descriptor of the map field to be converted.
+ Args:
+ value: A JSON object to convert the map field value.
+ message: A protocol message to record the converted data.
+ field: The descriptor of the map field to be converted.
- Raises:
- ParseError: In case of convert problems.
- """
- if not isinstance(value, dict):
- raise ParseError(
- 'Map field {0} must be in a dict which is {1}.'.format(
- field.name, value))
- key_field = field.message_type.fields_by_name['key']
- value_field = field.message_type.fields_by_name['value']
- for key in value:
- key_value = _ConvertScalarFieldValue(key, key_field, True)
- if value_field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- _ConvertMessage(value[key], getattr(message, field.name)[key_value])
- else:
- getattr(message, field.name)[key_value] = _ConvertScalarFieldValue(
- value[key], value_field)
+ Raises:
+ ParseError: In case of convert problems.
+ """
+ if not isinstance(value, dict):
+ raise ParseError(
+ 'Map field {0} must be in a dict which is {1}.'.format(
+ field.name, value))
+ key_field = field.message_type.fields_by_name['key']
+ value_field = field.message_type.fields_by_name['value']
+ for key in value:
+ key_value = _ConvertScalarFieldValue(key, key_field, True)
+ if value_field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ self.ConvertMessage(value[key], getattr(
+ message, field.name)[key_value])
+ else:
+ getattr(message, field.name)[key_value] = _ConvertScalarFieldValue(
+ value[key], value_field)
def _ConvertScalarFieldValue(value, field, require_str=False):
@@ -641,18 +647,18 @@ def _ConvertBool(value, require_str):
return value
_WKTJSONMETHODS = {
- 'google.protobuf.Any': [_AnyMessageToJsonObject,
- _ConvertAnyMessage],
- 'google.protobuf.Duration': [_GenericMessageToJsonObject,
- _ConvertGenericMessage],
- 'google.protobuf.FieldMask': [_GenericMessageToJsonObject,
- _ConvertGenericMessage],
- 'google.protobuf.ListValue': [_ListValueMessageToJsonObject,
- _ConvertListValueMessage],
- 'google.protobuf.Struct': [_StructMessageToJsonObject,
- _ConvertStructMessage],
- 'google.protobuf.Timestamp': [_GenericMessageToJsonObject,
- _ConvertGenericMessage],
- 'google.protobuf.Value': [_ValueMessageToJsonObject,
- _ConvertValueMessage]
+ 'google.protobuf.Any': ['_AnyMessageToJsonObject',
+ '_ConvertAnyMessage'],
+ 'google.protobuf.Duration': ['_GenericMessageToJsonObject',
+ '_ConvertGenericMessage'],
+ 'google.protobuf.FieldMask': ['_GenericMessageToJsonObject',
+ '_ConvertGenericMessage'],
+ 'google.protobuf.ListValue': ['_ListValueMessageToJsonObject',
+ '_ConvertListValueMessage'],
+ 'google.protobuf.Struct': ['_StructMessageToJsonObject',
+ '_ConvertStructMessage'],
+ 'google.protobuf.Timestamp': ['_GenericMessageToJsonObject',
+ '_ConvertGenericMessage'],
+ 'google.protobuf.Value': ['_ValueMessageToJsonObject',
+ '_ConvertValueMessage']
}
diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc
index 23557538..e6ef5ef5 100644
--- a/python/google/protobuf/pyext/descriptor.cc
+++ b/python/google/protobuf/pyext/descriptor.cc
@@ -172,12 +172,16 @@ template<>
const FileDescriptor* GetFileDescriptor(const OneofDescriptor* descriptor) {
return descriptor->containing_type()->file();
}
+template<>
+const FileDescriptor* GetFileDescriptor(const MethodDescriptor* descriptor) {
+ return descriptor->service()->file();
+}
// Converts options into a Python protobuf, and cache the result.
//
// This is a bit tricky because options can contain extension fields defined in
// the same proto file. In this case the options parsed from the serialized_pb
-// have unkown fields, and we need to parse them again.
+// have unknown fields, and we need to parse them again.
//
// Always returns a new reference.
template<class DescriptorClass>
@@ -204,11 +208,12 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
cdescriptor_pool::GetMessageClass(pool, message_type));
if (message_class == NULL) {
// The Options message was not found in the current DescriptorPool.
- // In this case, there cannot be extensions to these options, and we can
- // try to use the basic pool instead.
+ // This means that the pool cannot contain any extensions to the Options
+ // message either, so falling back to the basic pool we can only increase
+ // the chances of successfully parsing the options.
PyErr_Clear();
- message_class = cdescriptor_pool::GetMessageClass(
- GetDefaultDescriptorPool(), message_type);
+ pool = GetDefaultDescriptorPool();
+ message_class = cdescriptor_pool::GetMessageClass(pool, message_type);
}
if (message_class == NULL) {
PyErr_Format(PyExc_TypeError, "Could not retrieve class for Options: %s",
@@ -248,7 +253,7 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
// Cache the result.
Py_INCREF(value.get());
- (*pool->descriptor_options)[descriptor] = value.get();
+ (*descriptor_options)[descriptor] = value.get();
return value.release();
}
@@ -1091,7 +1096,7 @@ PyTypeObject PyEnumDescriptor_Type = {
0, // tp_weaklistoffset
0, // tp_iter
0, // tp_iternext
- enum_descriptor::Methods, // tp_getset
+ enum_descriptor::Methods, // tp_methods
0, // tp_members
enum_descriptor::Getters, // tp_getset
&descriptor::PyBaseDescriptor_Type, // tp_base
@@ -1275,6 +1280,10 @@ static PyObject* GetExtensionsByName(PyFileDescriptor* self, void *closure) {
return NewFileExtensionsByName(_GetDescriptor(self));
}
+static PyObject* GetServicesByName(PyFileDescriptor* self, void *closure) {
+ return NewFileServicesByName(_GetDescriptor(self));
+}
+
static PyObject* GetDependencies(PyFileDescriptor* self, void *closure) {
return NewFileDependencies(_GetDescriptor(self));
}
@@ -1324,6 +1333,7 @@ static PyGetSetDef Getters[] = {
{ "enum_types_by_name", (getter)GetEnumTypesByName, NULL, "Enums by name"},
{ "extensions_by_name", (getter)GetExtensionsByName, NULL,
"Extensions by name"},
+ { "services_by_name", (getter)GetServicesByName, NULL, "Services by name"},
{ "dependencies", (getter)GetDependencies, NULL, "Dependencies"},
{ "public_dependencies", (getter)GetPublicDependencies, NULL, "Dependencies"},
@@ -1452,16 +1462,45 @@ static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) {
}
}
+static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) {
+ const OneofOptions& options(_GetDescriptor(self)->options());
+ if (&options != &OneofOptions::default_instance()) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+static int SetHasOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("has_options");
+}
+
+static PyObject* GetOptions(PyBaseDescriptor *self) {
+ return GetOrBuildOptions(_GetDescriptor(self));
+}
+
+static int SetOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("_options");
+}
+
static PyGetSetDef Getters[] = {
{ "name", (getter)GetName, NULL, "Name"},
{ "full_name", (getter)GetFullName, NULL, "Full name"},
{ "index", (getter)GetIndex, NULL, "Index"},
{ "containing_type", (getter)GetContainingType, NULL, "Containing type"},
+ { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
+ { "_options", (getter)NULL, (setter)SetOptions, "Options"},
{ "fields", (getter)GetFields, NULL, "Fields"},
{NULL}
};
+static PyMethodDef Methods[] = {
+ { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS },
+ {NULL}
+};
+
} // namespace oneof_descriptor
PyTypeObject PyOneofDescriptor_Type = {
@@ -1492,7 +1531,7 @@ PyTypeObject PyOneofDescriptor_Type = {
0, // tp_weaklistoffset
0, // tp_iter
0, // tp_iternext
- 0, // tp_methods
+ oneof_descriptor::Methods, // tp_methods
0, // tp_members
oneof_descriptor::Getters, // tp_getset
&descriptor::PyBaseDescriptor_Type, // tp_base
@@ -1504,6 +1543,222 @@ PyObject* PyOneofDescriptor_FromDescriptor(
&PyOneofDescriptor_Type, oneof_descriptor, NULL);
}
+namespace service_descriptor {
+
+// Unchecked accessor to the C++ pointer.
+static const ServiceDescriptor* _GetDescriptor(
+ PyBaseDescriptor *self) {
+ return reinterpret_cast<const ServiceDescriptor*>(self->descriptor);
+}
+
+static PyObject* GetName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->name());
+}
+
+static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->full_name());
+}
+
+static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) {
+ return PyInt_FromLong(_GetDescriptor(self)->index());
+}
+
+static PyObject* GetMethods(PyBaseDescriptor* self, void *closure) {
+ return NewServiceMethodsSeq(_GetDescriptor(self));
+}
+
+static PyObject* GetMethodsByName(PyBaseDescriptor* self, void *closure) {
+ return NewServiceMethodsByName(_GetDescriptor(self));
+}
+
+static PyObject* FindMethodByName(PyBaseDescriptor *self, PyObject* arg) {
+ Py_ssize_t name_size;
+ char* name;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
+ return NULL;
+ }
+
+ const MethodDescriptor* method_descriptor =
+ _GetDescriptor(self)->FindMethodByName(string(name, name_size));
+ if (method_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find method %.200s", name);
+ return NULL;
+ }
+
+ return PyMethodDescriptor_FromDescriptor(method_descriptor);
+}
+
+static PyObject* GetOptions(PyBaseDescriptor *self) {
+ return GetOrBuildOptions(_GetDescriptor(self));
+}
+
+static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) {
+ return CopyToPythonProto<ServiceDescriptorProto>(_GetDescriptor(self),
+ target);
+}
+
+static PyGetSetDef Getters[] = {
+ { "name", (getter)GetName, NULL, "Name", NULL},
+ { "full_name", (getter)GetFullName, NULL, "Full name", NULL},
+ { "index", (getter)GetIndex, NULL, "Index", NULL},
+
+ { "methods", (getter)GetMethods, NULL, "Methods", NULL},
+ { "methods_by_name", (getter)GetMethodsByName, NULL, "Methods by name", NULL},
+ {NULL}
+};
+
+static PyMethodDef Methods[] = {
+ { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS },
+ { "CopyToProto", (PyCFunction)CopyToProto, METH_O, },
+ { "FindMethodByName", (PyCFunction)FindMethodByName, METH_O },
+ {NULL}
+};
+
+} // namespace service_descriptor
+
+PyTypeObject PyServiceDescriptor_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".ServiceDescriptor", // tp_name
+ sizeof(PyBaseDescriptor), // tp_basicsize
+ 0, // tp_itemsize
+ 0, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A Service Descriptor", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ service_descriptor::Methods, // tp_methods
+ 0, // tp_members
+ service_descriptor::Getters, // tp_getset
+ &descriptor::PyBaseDescriptor_Type, // tp_base
+};
+
+PyObject* PyServiceDescriptor_FromDescriptor(
+ const ServiceDescriptor* service_descriptor) {
+ return descriptor::NewInternedDescriptor(
+ &PyServiceDescriptor_Type, service_descriptor, NULL);
+}
+
+namespace method_descriptor {
+
+// Unchecked accessor to the C++ pointer.
+static const MethodDescriptor* _GetDescriptor(
+ PyBaseDescriptor *self) {
+ return reinterpret_cast<const MethodDescriptor*>(self->descriptor);
+}
+
+static PyObject* GetName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->name());
+}
+
+static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->full_name());
+}
+
+static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) {
+ return PyInt_FromLong(_GetDescriptor(self)->index());
+}
+
+static PyObject* GetContainingService(PyBaseDescriptor *self, void *closure) {
+ const ServiceDescriptor* containing_service =
+ _GetDescriptor(self)->service();
+ return PyServiceDescriptor_FromDescriptor(containing_service);
+}
+
+static PyObject* GetInputType(PyBaseDescriptor *self, void *closure) {
+ const Descriptor* input_type = _GetDescriptor(self)->input_type();
+ return PyMessageDescriptor_FromDescriptor(input_type);
+}
+
+static PyObject* GetOutputType(PyBaseDescriptor *self, void *closure) {
+ const Descriptor* output_type = _GetDescriptor(self)->output_type();
+ return PyMessageDescriptor_FromDescriptor(output_type);
+}
+
+static PyObject* GetOptions(PyBaseDescriptor *self) {
+ return GetOrBuildOptions(_GetDescriptor(self));
+}
+
+static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) {
+ return CopyToPythonProto<MethodDescriptorProto>(_GetDescriptor(self), target);
+}
+
+static PyGetSetDef Getters[] = {
+ { "name", (getter)GetName, NULL, "Name", NULL},
+ { "full_name", (getter)GetFullName, NULL, "Full name", NULL},
+ { "index", (getter)GetIndex, NULL, "Index", NULL},
+ { "containing_service", (getter)GetContainingService, NULL,
+ "Containing service", NULL},
+ { "input_type", (getter)GetInputType, NULL, "Input type", NULL},
+ { "output_type", (getter)GetOutputType, NULL, "Output type", NULL},
+ {NULL}
+};
+
+static PyMethodDef Methods[] = {
+ { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, },
+ { "CopyToProto", (PyCFunction)CopyToProto, METH_O, },
+ {NULL}
+};
+
+} // namespace method_descriptor
+
+PyTypeObject PyMethodDescriptor_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".MethodDescriptor", // tp_name
+ sizeof(PyBaseDescriptor), // tp_basicsize
+ 0, // tp_itemsize
+ 0, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A Method Descriptor", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ method_descriptor::Methods, // tp_methods
+ 0, // tp_members
+ method_descriptor::Getters, // tp_getset
+ &descriptor::PyBaseDescriptor_Type, // tp_base
+};
+
+PyObject* PyMethodDescriptor_FromDescriptor(
+ const MethodDescriptor* method_descriptor) {
+ return descriptor::NewInternedDescriptor(
+ &PyMethodDescriptor_Type, method_descriptor, NULL);
+}
+
// Add a enum values to a type dictionary.
static bool AddEnumValues(PyTypeObject *type,
const EnumDescriptor* enum_descriptor) {
@@ -1573,6 +1828,12 @@ bool InitDescriptor() {
if (PyType_Ready(&PyOneofDescriptor_Type) < 0)
return false;
+ if (PyType_Ready(&PyServiceDescriptor_Type) < 0)
+ return false;
+
+ if (PyType_Ready(&PyMethodDescriptor_Type) < 0)
+ return false;
+
if (!InitDescriptorMappingTypes())
return false;
diff --git a/python/google/protobuf/pyext/descriptor.h b/python/google/protobuf/pyext/descriptor.h
index eb99df18..1ae0e672 100644
--- a/python/google/protobuf/pyext/descriptor.h
+++ b/python/google/protobuf/pyext/descriptor.h
@@ -47,6 +47,8 @@ extern PyTypeObject PyEnumDescriptor_Type;
extern PyTypeObject PyEnumValueDescriptor_Type;
extern PyTypeObject PyFileDescriptor_Type;
extern PyTypeObject PyOneofDescriptor_Type;
+extern PyTypeObject PyServiceDescriptor_Type;
+extern PyTypeObject PyMethodDescriptor_Type;
// Wraps a Descriptor in a Python object.
// The C++ pointer is usually borrowed from the global DescriptorPool.
@@ -60,6 +62,10 @@ PyObject* PyEnumValueDescriptor_FromDescriptor(
PyObject* PyOneofDescriptor_FromDescriptor(const OneofDescriptor* descriptor);
PyObject* PyFileDescriptor_FromDescriptor(
const FileDescriptor* file_descriptor);
+PyObject* PyServiceDescriptor_FromDescriptor(
+ const ServiceDescriptor* descriptor);
+PyObject* PyMethodDescriptor_FromDescriptor(
+ const MethodDescriptor* descriptor);
// Alternate constructor of PyFileDescriptor, used when we already have a
// serialized FileDescriptorProto that can be cached.
diff --git a/python/google/protobuf/pyext/descriptor_containers.cc b/python/google/protobuf/pyext/descriptor_containers.cc
index e505d812..d0aae9c9 100644
--- a/python/google/protobuf/pyext/descriptor_containers.cc
+++ b/python/google/protobuf/pyext/descriptor_containers.cc
@@ -608,6 +608,24 @@ static PyObject* GetItem(PyContainer* self, Py_ssize_t index) {
return _NewObj_ByIndex(self, index);
}
+static PyObject *
+SeqSubscript(PyContainer* self, PyObject* item) {
+ if (PyIndex_Check(item)) {
+ Py_ssize_t index;
+ index = PyNumber_AsSsize_t(item, PyExc_IndexError);
+ if (index == -1 && PyErr_Occurred())
+ return NULL;
+ return GetItem(self, index);
+ }
+ // Materialize the list and delegate the operation to it.
+ ScopedPyObjectPtr list(PyObject_CallFunctionObjArgs(
+ reinterpret_cast<PyObject*>(&PyList_Type), self, NULL));
+ if (list == NULL) {
+ return NULL;
+ }
+ return Py_TYPE(list.get())->tp_as_mapping->mp_subscript(list.get(), item);
+}
+
// Returns the position of the item in the sequence, of -1 if not found.
// This function never fails.
int Find(PyContainer* self, PyObject* item) {
@@ -703,14 +721,20 @@ static PyMethodDef SeqMethods[] = {
};
static PySequenceMethods SeqSequenceMethods = {
- (lenfunc)Length, // sq_length
- 0, // sq_concat
- 0, // sq_repeat
- (ssizeargfunc)GetItem, // sq_item
- 0, // sq_slice
- 0, // sq_ass_item
- 0, // sq_ass_slice
- (objobjproc)SeqContains, // sq_contains
+ (lenfunc)Length, // sq_length
+ 0, // sq_concat
+ 0, // sq_repeat
+ (ssizeargfunc)GetItem, // sq_item
+ 0, // sq_slice
+ 0, // sq_ass_item
+ 0, // sq_ass_slice
+ (objobjproc)SeqContains, // sq_contains
+};
+
+static PyMappingMethods SeqMappingMethods = {
+ (lenfunc)Length, // mp_length
+ (binaryfunc)SeqSubscript, // mp_subscript
+ 0, // mp_ass_subscript
};
PyTypeObject DescriptorSequence_Type = {
@@ -726,7 +750,7 @@ PyTypeObject DescriptorSequence_Type = {
(reprfunc)ContainerRepr, // tp_repr
0, // tp_as_number
&SeqSequenceMethods, // tp_as_sequence
- 0, // tp_as_mapping
+ &SeqMappingMethods, // tp_as_mapping
0, // tp_hash
0, // tp_call
0, // tp_str
@@ -1407,6 +1431,68 @@ PyObject* NewOneofFieldsSeq(ParentDescriptor descriptor) {
} // namespace oneof_descriptor
+namespace service_descriptor {
+
+typedef const ServiceDescriptor* ParentDescriptor;
+
+static ParentDescriptor GetDescriptor(PyContainer* self) {
+ return reinterpret_cast<ParentDescriptor>(self->descriptor);
+}
+
+namespace methods {
+
+typedef const MethodDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->method_count();
+}
+
+static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+ return GetDescriptor(self)->FindMethodByName(name);
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->method(index);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyMethodDescriptor_FromDescriptor(item);
+}
+
+static const string& GetItemName(ItemDescriptor item) {
+ return item->name();
+}
+
+static int GetItemIndex(ItemDescriptor item) {
+ return item->index();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "ServiceMethods",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)GetByName,
+ (GetByCamelcaseNameMethod)NULL,
+ (GetByNumberMethod)NULL,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)GetItemName,
+ (GetItemCamelcaseNameMethod)NULL,
+ (GetItemNumberMethod)NULL,
+ (GetItemIndexMethod)GetItemIndex,
+};
+
+} // namespace methods
+
+PyObject* NewServiceMethodsSeq(ParentDescriptor descriptor) {
+ return descriptor::NewSequence(&methods::ContainerDef, descriptor);
+}
+
+PyObject* NewServiceMethodsByName(ParentDescriptor descriptor) {
+ return descriptor::NewMappingByName(&methods::ContainerDef, descriptor);
+}
+
+} // namespace service_descriptor
+
namespace file_descriptor {
typedef const FileDescriptor* ParentDescriptor;
@@ -1459,7 +1545,7 @@ static DescriptorContainerDef ContainerDef = {
} // namespace messages
-PyObject* NewFileMessageTypesByName(const FileDescriptor* descriptor) {
+PyObject* NewFileMessageTypesByName(ParentDescriptor descriptor) {
return descriptor::NewMappingByName(&messages::ContainerDef, descriptor);
}
@@ -1507,7 +1593,7 @@ static DescriptorContainerDef ContainerDef = {
} // namespace enums
-PyObject* NewFileEnumTypesByName(const FileDescriptor* descriptor) {
+PyObject* NewFileEnumTypesByName(ParentDescriptor descriptor) {
return descriptor::NewMappingByName(&enums::ContainerDef, descriptor);
}
@@ -1555,10 +1641,58 @@ static DescriptorContainerDef ContainerDef = {
} // namespace extensions
-PyObject* NewFileExtensionsByName(const FileDescriptor* descriptor) {
+PyObject* NewFileExtensionsByName(ParentDescriptor descriptor) {
return descriptor::NewMappingByName(&extensions::ContainerDef, descriptor);
}
+namespace services {
+
+typedef const ServiceDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->service_count();
+}
+
+static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+ return GetDescriptor(self)->FindServiceByName(name);
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->service(index);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyServiceDescriptor_FromDescriptor(item);
+}
+
+static const string& GetItemName(ItemDescriptor item) {
+ return item->name();
+}
+
+static int GetItemIndex(ItemDescriptor item) {
+ return item->index();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "FileServices",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)GetByName,
+ (GetByCamelcaseNameMethod)NULL,
+ (GetByNumberMethod)NULL,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)GetItemName,
+ (GetItemCamelcaseNameMethod)NULL,
+ (GetItemNumberMethod)NULL,
+ (GetItemIndexMethod)GetItemIndex,
+};
+
+} // namespace services
+
+PyObject* NewFileServicesByName(const FileDescriptor* descriptor) {
+ return descriptor::NewMappingByName(&services::ContainerDef, descriptor);
+}
+
namespace dependencies {
typedef const FileDescriptor* ItemDescriptor;
diff --git a/python/google/protobuf/pyext/descriptor_containers.h b/python/google/protobuf/pyext/descriptor_containers.h
index ce40747d..83de07b6 100644
--- a/python/google/protobuf/pyext/descriptor_containers.h
+++ b/python/google/protobuf/pyext/descriptor_containers.h
@@ -43,6 +43,7 @@ class Descriptor;
class FileDescriptor;
class EnumDescriptor;
class OneofDescriptor;
+class ServiceDescriptor;
namespace python {
@@ -89,10 +90,17 @@ PyObject* NewFileEnumTypesByName(const FileDescriptor* descriptor);
PyObject* NewFileExtensionsByName(const FileDescriptor* descriptor);
+PyObject* NewFileServicesByName(const FileDescriptor* descriptor);
+
PyObject* NewFileDependencies(const FileDescriptor* descriptor);
PyObject* NewFilePublicDependencies(const FileDescriptor* descriptor);
} // namespace file_descriptor
+namespace service_descriptor {
+PyObject* NewServiceMethodsSeq(const ServiceDescriptor* descriptor);
+PyObject* NewServiceMethodsByName(const ServiceDescriptor* descriptor);
+} // namespace service_descriptor
+
} // namespace python
} // namespace protobuf
diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc
index 1faff96b..cfd98690 100644
--- a/python/google/protobuf/pyext/descriptor_pool.cc
+++ b/python/google/protobuf/pyext/descriptor_pool.cc
@@ -305,6 +305,40 @@ PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) {
return PyOneofDescriptor_FromDescriptor(oneof_descriptor);
}
+PyObject* FindServiceByName(PyDescriptorPool* self, PyObject* arg) {
+ Py_ssize_t name_size;
+ char* name;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
+ return NULL;
+ }
+
+ const ServiceDescriptor* service_descriptor =
+ self->pool->FindServiceByName(string(name, name_size));
+ if (service_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find service %.200s", name);
+ return NULL;
+ }
+
+ return PyServiceDescriptor_FromDescriptor(service_descriptor);
+}
+
+PyObject* FindMethodByName(PyDescriptorPool* self, PyObject* arg) {
+ Py_ssize_t name_size;
+ char* name;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
+ return NULL;
+ }
+
+ const MethodDescriptor* method_descriptor =
+ self->pool->FindMethodByName(string(name, name_size));
+ if (method_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find method %.200s", name);
+ return NULL;
+ }
+
+ return PyMethodDescriptor_FromDescriptor(method_descriptor);
+}
+
PyObject* FindFileContainingSymbol(PyDescriptorPool* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
@@ -491,6 +525,10 @@ static PyMethodDef Methods[] = {
"Searches for enum type descriptor by full name." },
{ "FindOneofByName", (PyCFunction)FindOneofByName, METH_O,
"Searches for oneof descriptor by full name." },
+ { "FindServiceByName", (PyCFunction)FindServiceByName, METH_O,
+ "Searches for service descriptor by full name." },
+ { "FindMethodByName", (PyCFunction)FindMethodByName, METH_O,
+ "Searches for method descriptor by full name." },
{ "FindFileContainingSymbol", (PyCFunction)FindFileContainingSymbol, METH_O,
"Gets the FileDescriptor containing the specified symbol." },
diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc
index e022406d..90438df1 100644
--- a/python/google/protobuf/pyext/map_container.cc
+++ b/python/google/protobuf/pyext/map_container.cc
@@ -39,7 +39,6 @@
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h>
-#include <google/protobuf/stubs/scoped_ptr.h>
#include <google/protobuf/map_field.h>
#include <google/protobuf/map.h>
#include <google/protobuf/message.h>
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index 83c151ff..a9261f20 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -1593,23 +1593,20 @@ struct ReleaseChild : public ChildVisitor {
parent_(parent) {}
int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) {
- return repeated_composite_container::Release(
- reinterpret_cast<RepeatedCompositeContainer*>(container));
+ return repeated_composite_container::Release(container);
}
int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) {
- return repeated_scalar_container::Release(
- reinterpret_cast<RepeatedScalarContainer*>(container));
+ return repeated_scalar_container::Release(container);
}
int VisitMapContainer(MapContainer* container) {
- return reinterpret_cast<MapContainer*>(container)->Release();
+ return container->Release();
}
int VisitCMessage(CMessage* cmessage,
const FieldDescriptor* field_descriptor) {
- return ReleaseSubMessage(parent_, field_descriptor,
- reinterpret_cast<CMessage*>(cmessage));
+ return ReleaseSubMessage(parent_, field_descriptor, cmessage);
}
CMessage* parent_;
@@ -1903,7 +1900,7 @@ static bool allow_oversize_protos = false;
// Provide a method in the module to set allow_oversize_protos to a boolean
// value. This method returns the newly value of allow_oversize_protos.
-static PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) {
+PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) {
if (!arg || !PyBool_Check(arg)) {
PyErr_SetString(PyExc_TypeError,
"Argument to SetAllowOversizeProtos must be boolean");
@@ -3044,6 +3041,10 @@ bool InitProto2MessageModule(PyObject *m) {
&PyFileDescriptor_Type));
PyModule_AddObject(m, "OneofDescriptor", reinterpret_cast<PyObject*>(
&PyOneofDescriptor_Type));
+ PyModule_AddObject(m, "ServiceDescriptor", reinterpret_cast<PyObject*>(
+ &PyServiceDescriptor_Type));
+ PyModule_AddObject(m, "MethodDescriptor", reinterpret_cast<PyObject*>(
+ &PyMethodDescriptor_Type));
PyObject* enum_type_wrapper = PyImport_ImportModule(
"google.protobuf.internal.enum_type_wrapper");
@@ -3081,53 +3082,4 @@ bool InitProto2MessageModule(PyObject *m) {
} // namespace python
} // namespace protobuf
-static PyMethodDef ModuleMethods[] = {
- {"SetAllowOversizeProtos",
- (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos,
- METH_O, "Enable/disable oversize proto parsing."},
- { NULL, NULL}
-};
-
-#if PY_MAJOR_VERSION >= 3
-static struct PyModuleDef _module = {
- PyModuleDef_HEAD_INIT,
- "_message",
- google::protobuf::python::module_docstring,
- -1,
- ModuleMethods, /* m_methods */
- NULL,
- NULL,
- NULL,
- NULL
-};
-#define INITFUNC PyInit__message
-#define INITFUNC_ERRORVAL NULL
-#else // Python 2
-#define INITFUNC init_message
-#define INITFUNC_ERRORVAL
-#endif
-
-extern "C" {
- PyMODINIT_FUNC INITFUNC(void) {
- PyObject* m;
-#if PY_MAJOR_VERSION >= 3
- m = PyModule_Create(&_module);
-#else
- m = Py_InitModule3("_message", ModuleMethods,
- google::protobuf::python::module_docstring);
-#endif
- if (m == NULL) {
- return INITFUNC_ERRORVAL;
- }
-
- if (!google::protobuf::python::InitProto2MessageModule(m)) {
- Py_DECREF(m);
- return INITFUNC_ERRORVAL;
- }
-
-#if PY_MAJOR_VERSION >= 3
- return m;
-#endif
- }
-}
} // namespace google
diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h
index 3a4bec81..8b399e05 100644
--- a/python/google/protobuf/pyext/message.h
+++ b/python/google/protobuf/pyext/message.h
@@ -54,7 +54,7 @@ class MessageFactory;
#ifdef _SHARED_PTR_H
using std::shared_ptr;
-using ::std::string;
+using std::string;
#else
using internal::shared_ptr;
#endif
@@ -269,6 +269,8 @@ int AssureWritable(CMessage* self);
// even in the case of extensions.
PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message);
+PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg);
+
} // namespace cmessage
@@ -354,6 +356,8 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
extern PyObject* PickleError_class;
+bool InitProto2MessageModule(PyObject *m);
+
} // namespace python
} // namespace protobuf
diff --git a/python/google/protobuf/pyext/message_module.cc b/python/google/protobuf/pyext/message_module.cc
new file mode 100644
index 00000000..d90d9de3
--- /dev/null
+++ b/python/google/protobuf/pyext/message_module.cc
@@ -0,0 +1,88 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <google/protobuf/pyext/message.h>
+
+static const char module_docstring[] =
+"python-proto2 is a module that can be used to enhance proto2 Python API\n"
+"performance.\n"
+"\n"
+"It provides access to the protocol buffers C++ reflection API that\n"
+"implements the basic protocol buffer functions.";
+
+static PyMethodDef ModuleMethods[] = {
+ {"SetAllowOversizeProtos",
+ (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos,
+ METH_O, "Enable/disable oversize proto parsing."},
+ { NULL, NULL}
+};
+
+#if PY_MAJOR_VERSION >= 3
+static struct PyModuleDef _module = {
+ PyModuleDef_HEAD_INIT,
+ "_message",
+ module_docstring,
+ -1,
+ ModuleMethods, /* m_methods */
+ NULL,
+ NULL,
+ NULL,
+ NULL
+};
+#define INITFUNC PyInit__message
+#define INITFUNC_ERRORVAL NULL
+#else // Python 2
+#define INITFUNC init_message
+#define INITFUNC_ERRORVAL
+#endif
+
+extern "C" {
+ PyMODINIT_FUNC INITFUNC(void) {
+ PyObject* m;
+#if PY_MAJOR_VERSION >= 3
+ m = PyModule_Create(&_module);
+#else
+ m = Py_InitModule3("_message", ModuleMethods,
+ module_docstring);
+#endif
+ if (m == NULL) {
+ return INITFUNC_ERRORVAL;
+ }
+
+ if (!google::protobuf::python::InitProto2MessageModule(m)) {
+ Py_DECREF(m);
+ return INITFUNC_ERRORVAL;
+ }
+
+#if PY_MAJOR_VERSION >= 3
+ return m;
+#endif
+ }
+}
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
index 6f1e3c8b..c4b23c37 100755
--- a/python/google/protobuf/text_format.py
+++ b/python/google/protobuf/text_format.py
@@ -48,15 +48,15 @@ import re
import six
if six.PY3:
- long = int
+ long = int # pylint: disable=redefined-builtin,invalid-name
+# pylint: disable=g-import-not-at-top
from google.protobuf.internal import type_checkers
from google.protobuf import descriptor
from google.protobuf import text_encoding
-__all__ = ['MessageToString', 'PrintMessage', 'PrintField',
- 'PrintFieldValue', 'Merge']
-
+__all__ = ['MessageToString', 'PrintMessage', 'PrintField', 'PrintFieldValue',
+ 'Merge']
_INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(),
type_checkers.Int32ValueChecker(),
@@ -67,6 +67,7 @@ _FLOAT_NAN = re.compile('nanf?', re.IGNORECASE)
_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT,
descriptor.FieldDescriptor.CPPTYPE_DOUBLE])
_QUOTES = frozenset(("'", '"'))
+_ANY_FULL_TYPE_NAME = 'google.protobuf.Any'
class Error(Exception):
@@ -74,10 +75,30 @@ class Error(Exception):
class ParseError(Error):
- """Thrown in case of text parsing error."""
+ """Thrown in case of text parsing or tokenizing error."""
+
+ def __init__(self, message=None, line=None, column=None):
+ if message is not None and line is not None:
+ loc = str(line)
+ if column is not None:
+ loc += ':{}'.format(column)
+ message = '{} : {}'.format(loc, message)
+ if message is not None:
+ super(ParseError, self).__init__(message)
+ else:
+ super(ParseError, self).__init__()
+ self._line = line
+ self._column = column
+
+ def GetLine(self):
+ return self._line
+
+ def GetColumn(self):
+ return self._column
class TextWriter(object):
+
def __init__(self, as_utf8):
if six.PY2:
self._writer = io.BytesIO()
@@ -97,9 +118,15 @@ class TextWriter(object):
return self._writer.getvalue()
-def MessageToString(message, as_utf8=False, as_one_line=False,
- pointy_brackets=False, use_index_order=False,
- float_format=None, use_field_number=False):
+def MessageToString(message,
+ as_utf8=False,
+ as_one_line=False,
+ pointy_brackets=False,
+ use_index_order=False,
+ float_format=None,
+ use_field_number=False,
+ descriptor_pool=None,
+ indent=0):
"""Convert protobuf message to text format.
Floating point values can be formatted compactly with 15 digits of
@@ -119,14 +146,16 @@ def MessageToString(message, as_utf8=False, as_one_line=False,
float_format: If set, use this to specify floating point number formatting
(per the "Format Specification Mini-Language"); otherwise, str() is used.
use_field_number: If True, print field numbers instead of names.
+ descriptor_pool: A DescriptorPool used to resolve Any types.
+ indent: The indent level, in terms of spaces, for pretty print.
Returns:
A string of the text formatted protocol buffer message.
"""
out = TextWriter(as_utf8)
- printer = _Printer(out, 0, as_utf8, as_one_line,
- pointy_brackets, use_index_order, float_format,
- use_field_number)
+ printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ use_index_order, float_format, use_field_number,
+ descriptor_pool)
printer.PrintMessage(message)
result = out.getvalue()
out.close()
@@ -141,39 +170,87 @@ def _IsMapEntry(field):
field.message_type.GetOptions().map_entry)
-def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False,
- pointy_brackets=False, use_index_order=False,
- float_format=None, use_field_number=False):
- printer = _Printer(out, indent, as_utf8, as_one_line,
- pointy_brackets, use_index_order, float_format,
- use_field_number)
+def PrintMessage(message,
+ out,
+ indent=0,
+ as_utf8=False,
+ as_one_line=False,
+ pointy_brackets=False,
+ use_index_order=False,
+ float_format=None,
+ use_field_number=False,
+ descriptor_pool=None):
+ printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ use_index_order, float_format, use_field_number,
+ descriptor_pool)
printer.PrintMessage(message)
-def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False,
- pointy_brackets=False, use_index_order=False, float_format=None):
+def PrintField(field,
+ value,
+ out,
+ indent=0,
+ as_utf8=False,
+ as_one_line=False,
+ pointy_brackets=False,
+ use_index_order=False,
+ float_format=None):
"""Print a single field name/value pair."""
- printer = _Printer(out, indent, as_utf8, as_one_line,
- pointy_brackets, use_index_order, float_format)
+ printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ use_index_order, float_format)
printer.PrintField(field, value)
-def PrintFieldValue(field, value, out, indent=0, as_utf8=False,
- as_one_line=False, pointy_brackets=False,
+def PrintFieldValue(field,
+ value,
+ out,
+ indent=0,
+ as_utf8=False,
+ as_one_line=False,
+ pointy_brackets=False,
use_index_order=False,
float_format=None):
"""Print a single field value (not including name)."""
- printer = _Printer(out, indent, as_utf8, as_one_line,
- pointy_brackets, use_index_order, float_format)
+ printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ use_index_order, float_format)
printer.PrintFieldValue(field, value)
+def _BuildMessageFromTypeName(type_name, descriptor_pool):
+ """Returns a protobuf message instance.
+
+ Args:
+ type_name: Fully-qualified protobuf message type name string.
+ descriptor_pool: DescriptorPool instance.
+
+ Returns:
+ A Message instance of type matching type_name, or None if the a Descriptor
+ wasn't found matching type_name.
+ """
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf import message_factory
+ factory = message_factory.MessageFactory(descriptor_pool)
+ try:
+ message_descriptor = descriptor_pool.FindMessageTypeByName(type_name)
+ except KeyError:
+ return None
+ message_type = factory.GetPrototype(message_descriptor)
+ return message_type()
+
+
class _Printer(object):
"""Text format printer for protocol message."""
- def __init__(self, out, indent=0, as_utf8=False, as_one_line=False,
- pointy_brackets=False, use_index_order=False, float_format=None,
- use_field_number=False):
+ def __init__(self,
+ out,
+ indent=0,
+ as_utf8=False,
+ as_one_line=False,
+ pointy_brackets=False,
+ use_index_order=False,
+ float_format=None,
+ use_field_number=False,
+ descriptor_pool=None):
"""Initialize the Printer.
Floating point values can be formatted compactly with 15 digits of
@@ -195,6 +272,7 @@ class _Printer(object):
(per the "Format Specification Mini-Language"); otherwise, str() is
used.
use_field_number: If True, print field numbers instead of names.
+ descriptor_pool: A DescriptorPool used to resolve Any types.
"""
self.out = out
self.indent = indent
@@ -204,6 +282,20 @@ class _Printer(object):
self.use_index_order = use_index_order
self.float_format = float_format
self.use_field_number = use_field_number
+ self.descriptor_pool = descriptor_pool
+
+ def _TryPrintAsAnyMessage(self, message):
+ """Serializes if message is a google.protobuf.Any field."""
+ packed_message = _BuildMessageFromTypeName(message.TypeName(),
+ self.descriptor_pool)
+ if packed_message:
+ packed_message.MergeFromString(message.value)
+ self.out.write('%s[%s]' % (self.indent * ' ', message.type_url))
+ self._PrintMessageFieldValue(packed_message)
+ self.out.write(' ' if self.as_one_line else '\n')
+ return True
+ else:
+ return False
def PrintMessage(self, message):
"""Convert protobuf message to text format.
@@ -211,6 +303,9 @@ class _Printer(object):
Args:
message: The protocol buffers message.
"""
+ if (message.DESCRIPTOR.full_name == _ANY_FULL_TYPE_NAME and
+ self.descriptor_pool and self._TryPrintAsAnyMessage(message)):
+ return
fields = message.ListFields()
if self.use_index_order:
fields.sort(key=lambda x: x[0].index)
@@ -222,8 +317,8 @@ class _Printer(object):
# of this file to work around.
#
# TODO(haberman): refactor and optimize if this becomes an issue.
- entry_submsg = field.message_type._concrete_class(
- key=key, value=value[key])
+ entry_submsg = field.message_type._concrete_class(key=key,
+ value=value[key])
self.PrintField(field, entry_submsg)
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
for element in value:
@@ -264,6 +359,25 @@ class _Printer(object):
else:
out.write('\n')
+ def _PrintMessageFieldValue(self, value):
+ if self.pointy_brackets:
+ openb = '<'
+ closeb = '>'
+ else:
+ openb = '{'
+ closeb = '}'
+
+ if self.as_one_line:
+ self.out.write(' %s ' % openb)
+ self.PrintMessage(value)
+ self.out.write(closeb)
+ else:
+ self.out.write(' %s\n' % openb)
+ self.indent += 2
+ self.PrintMessage(value)
+ self.indent -= 2
+ self.out.write(' ' * self.indent + closeb)
+
def PrintFieldValue(self, field, value):
"""Print a single field value (not including name).
@@ -274,24 +388,8 @@ class _Printer(object):
value: The value of the field.
"""
out = self.out
- if self.pointy_brackets:
- openb = '<'
- closeb = '>'
- else:
- openb = '{'
- closeb = '}'
-
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- if self.as_one_line:
- out.write(' %s ' % openb)
- self.PrintMessage(value)
- out.write(closeb)
- else:
- out.write(' %s\n' % openb)
- self.indent += 2
- self.PrintMessage(value)
- self.indent -= 2
- out.write(' ' * self.indent + closeb)
+ self._PrintMessageFieldValue(value)
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
enum_value = field.enum_type.values_by_number.get(value, None)
if enum_value is not None:
@@ -322,9 +420,11 @@ class _Printer(object):
out.write(str(value))
-def Parse(text, message,
- allow_unknown_extension=False, allow_field_number=False):
- """Parses an text representation of a protocol message into a message.
+def Parse(text,
+ message,
+ allow_unknown_extension=False,
+ allow_field_number=False):
+ """Parses a text representation of a protocol message into a message.
Args:
text: Message text representation.
@@ -341,13 +441,16 @@ def Parse(text, message,
"""
if not isinstance(text, str):
text = text.decode('utf-8')
- return ParseLines(text.split('\n'), message, allow_unknown_extension,
- allow_field_number)
+ return ParseLines(
+ text.split('\n'), message, allow_unknown_extension, allow_field_number)
-def Merge(text, message, allow_unknown_extension=False,
- allow_field_number=False):
- """Parses an text representation of a protocol message into a message.
+def Merge(text,
+ message,
+ allow_unknown_extension=False,
+ allow_field_number=False,
+ descriptor_pool=None):
+ """Parses a text representation of a protocol message into a message.
Like Parse(), but allows repeated values for a non-repeated field, and uses
the last one.
@@ -358,6 +461,7 @@ def Merge(text, message, allow_unknown_extension=False,
allow_unknown_extension: if True, skip over missing extensions and keep
parsing
allow_field_number: if True, both field number and field name are allowed.
+ descriptor_pool: A DescriptorPool used to resolve Any types.
Returns:
The same message passed as argument.
@@ -365,13 +469,19 @@ def Merge(text, message, allow_unknown_extension=False,
Raises:
ParseError: On text parsing problems.
"""
- return MergeLines(text.split('\n'), message, allow_unknown_extension,
- allow_field_number)
+ return MergeLines(
+ text.split('\n'),
+ message,
+ allow_unknown_extension,
+ allow_field_number,
+ descriptor_pool=descriptor_pool)
-def ParseLines(lines, message, allow_unknown_extension=False,
+def ParseLines(lines,
+ message,
+ allow_unknown_extension=False,
allow_field_number=False):
- """Parses an text representation of a protocol message into a message.
+ """Parses a text representation of a protocol message into a message.
Args:
lines: An iterable of lines of a message's text representation.
@@ -379,6 +489,7 @@ def ParseLines(lines, message, allow_unknown_extension=False,
allow_unknown_extension: if True, skip over missing extensions and keep
parsing
allow_field_number: if True, both field number and field name are allowed.
+ descriptor_pool: A DescriptorPool used to resolve Any types.
Returns:
The same message passed as argument.
@@ -390,9 +501,12 @@ def ParseLines(lines, message, allow_unknown_extension=False,
return parser.ParseLines(lines, message)
-def MergeLines(lines, message, allow_unknown_extension=False,
- allow_field_number=False):
- """Parses an text representation of a protocol message into a message.
+def MergeLines(lines,
+ message,
+ allow_unknown_extension=False,
+ allow_field_number=False,
+ descriptor_pool=None):
+ """Parses a text representation of a protocol message into a message.
Args:
lines: An iterable of lines of a message's text representation.
@@ -407,41 +521,47 @@ def MergeLines(lines, message, allow_unknown_extension=False,
Raises:
ParseError: On text parsing problems.
"""
- parser = _Parser(allow_unknown_extension, allow_field_number)
+ parser = _Parser(allow_unknown_extension,
+ allow_field_number,
+ descriptor_pool=descriptor_pool)
return parser.MergeLines(lines, message)
class _Parser(object):
"""Text format parser for protocol message."""
- def __init__(self, allow_unknown_extension=False, allow_field_number=False):
+ def __init__(self,
+ allow_unknown_extension=False,
+ allow_field_number=False,
+ descriptor_pool=None):
self.allow_unknown_extension = allow_unknown_extension
self.allow_field_number = allow_field_number
+ self.descriptor_pool = descriptor_pool
def ParseFromString(self, text, message):
- """Parses an text representation of a protocol message into a message."""
+ """Parses a text representation of a protocol message into a message."""
if not isinstance(text, str):
text = text.decode('utf-8')
return self.ParseLines(text.split('\n'), message)
def ParseLines(self, lines, message):
- """Parses an text representation of a protocol message into a message."""
+ """Parses a text representation of a protocol message into a message."""
self._allow_multiple_scalars = False
self._ParseOrMerge(lines, message)
return message
def MergeFromString(self, text, message):
- """Merges an text representation of a protocol message into a message."""
+ """Merges a text representation of a protocol message into a message."""
return self._MergeLines(text.split('\n'), message)
def MergeLines(self, lines, message):
- """Merges an text representation of a protocol message into a message."""
+ """Merges a text representation of a protocol message into a message."""
self._allow_multiple_scalars = True
self._ParseOrMerge(lines, message)
return message
def _ParseOrMerge(self, lines, message):
- """Converts an text representation of a protocol message into a message.
+ """Converts a text representation of a protocol message into a message.
Args:
lines: Lines of a message's text representation.
@@ -450,7 +570,7 @@ class _Parser(object):
Raises:
ParseError: On text parsing problems.
"""
- tokenizer = _Tokenizer(lines)
+ tokenizer = Tokenizer(lines)
while not tokenizer.AtEnd():
self._MergeField(tokenizer, message)
@@ -491,13 +611,13 @@ class _Parser(object):
'Extension "%s" not registered.' % name)
elif message_descriptor != field.containing_type:
raise tokenizer.ParseErrorPreviousToken(
- 'Extension "%s" does not extend message type "%s".' % (
- name, message_descriptor.full_name))
+ 'Extension "%s" does not extend message type "%s".' %
+ (name, message_descriptor.full_name))
tokenizer.Consume(']')
else:
- name = tokenizer.ConsumeIdentifier()
+ name = tokenizer.ConsumeIdentifierOrNumber()
if self.allow_field_number and name.isdigit():
number = ParseInteger(name, True, True)
field = message_descriptor.fields_by_number.get(number, None)
@@ -520,8 +640,8 @@ class _Parser(object):
if not field:
raise tokenizer.ParseErrorPreviousToken(
- 'Message type "%s" has no field named "%s".' % (
- message_descriptor.full_name, name))
+ 'Message type "%s" has no field named "%s".' %
+ (message_descriptor.full_name, name))
if field:
if not self._allow_multiple_scalars and field.containing_oneof:
@@ -532,9 +652,9 @@ class _Parser(object):
if which_oneof is not None and which_oneof != field.name:
raise tokenizer.ParseErrorPreviousToken(
'Field "%s" is specified along with field "%s", another member '
- 'of oneof "%s" for message type "%s".' % (
- field.name, which_oneof, field.containing_oneof.name,
- message_descriptor.full_name))
+ 'of oneof "%s" for message type "%s".' %
+ (field.name, which_oneof, field.containing_oneof.name,
+ message_descriptor.full_name))
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
tokenizer.TryConsume(':')
@@ -543,12 +663,13 @@ class _Parser(object):
tokenizer.Consume(':')
merger = self._MergeScalarField
- if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED
- and tokenizer.TryConsume('[')):
+ if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED and
+ tokenizer.TryConsume('[')):
# Short repeated format, e.g. "foo: [1, 2, 3]"
while True:
merger(tokenizer, message, field)
- if tokenizer.TryConsume(']'): break
+ if tokenizer.TryConsume(']'):
+ break
tokenizer.Consume(',')
else:
@@ -563,6 +684,21 @@ class _Parser(object):
if not tokenizer.TryConsume(','):
tokenizer.TryConsume(';')
+ def _ConsumeAnyTypeUrl(self, tokenizer):
+ """Consumes a google.protobuf.Any type URL and returns the type name."""
+ # Consume "type.googleapis.com/".
+ tokenizer.ConsumeIdentifier()
+ tokenizer.Consume('.')
+ tokenizer.ConsumeIdentifier()
+ tokenizer.Consume('.')
+ tokenizer.ConsumeIdentifier()
+ tokenizer.Consume('/')
+ # Consume the fully-qualified type name.
+ name = [tokenizer.ConsumeIdentifier()]
+ while tokenizer.TryConsume('.'):
+ name.append(tokenizer.ConsumeIdentifier())
+ return '.'.join(name)
+
def _MergeMessageField(self, tokenizer, message, field):
"""Merges a single scalar field into a message.
@@ -582,7 +718,34 @@ class _Parser(object):
tokenizer.Consume('{')
end_token = '}'
- if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ if (field.message_type.full_name == _ANY_FULL_TYPE_NAME and
+ tokenizer.TryConsume('[')):
+ packed_type_name = self._ConsumeAnyTypeUrl(tokenizer)
+ tokenizer.Consume(']')
+ tokenizer.TryConsume(':')
+ if tokenizer.TryConsume('<'):
+ expanded_any_end_token = '>'
+ else:
+ tokenizer.Consume('{')
+ expanded_any_end_token = '}'
+ if not self.descriptor_pool:
+ raise ParseError('Descriptor pool required to parse expanded Any field')
+ expanded_any_sub_message = _BuildMessageFromTypeName(packed_type_name,
+ self.descriptor_pool)
+ if not expanded_any_sub_message:
+ raise ParseError('Type %s not found in descriptor pool' %
+ packed_type_name)
+ while not tokenizer.TryConsume(expanded_any_end_token):
+ if tokenizer.AtEnd():
+ raise tokenizer.ParseErrorPreviousToken('Expected "%s".' %
+ (expanded_any_end_token,))
+ self._MergeField(tokenizer, expanded_any_sub_message)
+ if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ any_message = getattr(message, field.name).add()
+ else:
+ any_message = getattr(message, field.name)
+ any_message.Pack(expanded_any_sub_message)
+ elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
if field.is_extension:
sub_message = message.Extensions[field].add()
elif is_map_entry:
@@ -628,17 +791,17 @@ class _Parser(object):
if field.type in (descriptor.FieldDescriptor.TYPE_INT32,
descriptor.FieldDescriptor.TYPE_SINT32,
descriptor.FieldDescriptor.TYPE_SFIXED32):
- value = tokenizer.ConsumeInt32()
+ value = _ConsumeInt32(tokenizer)
elif field.type in (descriptor.FieldDescriptor.TYPE_INT64,
descriptor.FieldDescriptor.TYPE_SINT64,
descriptor.FieldDescriptor.TYPE_SFIXED64):
- value = tokenizer.ConsumeInt64()
+ value = _ConsumeInt64(tokenizer)
elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32,
descriptor.FieldDescriptor.TYPE_FIXED32):
- value = tokenizer.ConsumeUint32()
+ value = _ConsumeUint32(tokenizer)
elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64,
descriptor.FieldDescriptor.TYPE_FIXED64):
- value = tokenizer.ConsumeUint64()
+ value = _ConsumeUint64(tokenizer)
elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT,
descriptor.FieldDescriptor.TYPE_DOUBLE):
value = tokenizer.ConsumeFloat()
@@ -753,13 +916,12 @@ def _SkipFieldValue(tokenizer):
return
if (not tokenizer.TryConsumeIdentifier() and
- not tokenizer.TryConsumeInt64() and
- not tokenizer.TryConsumeUint64() and
+ not _TryConsumeInt64(tokenizer) and not _TryConsumeUint64(tokenizer) and
not tokenizer.TryConsumeFloat()):
raise ParseError('Invalid field value: ' + tokenizer.token)
-class _Tokenizer(object):
+class Tokenizer(object):
"""Protocol buffer text representation tokenizer.
This class handles the lower level string parsing by splitting it into
@@ -768,17 +930,20 @@ class _Tokenizer(object):
It was directly ported from the Java protocol buffer API.
"""
- _WHITESPACE = re.compile('(\\s|(#.*$))+', re.MULTILINE)
+ _WHITESPACE = re.compile(r'\s+')
+ _COMMENT = re.compile(r'(\s*#.*$)', re.MULTILINE)
+ _WHITESPACE_OR_COMMENT = re.compile(r'(\s|(#.*$))+', re.MULTILINE)
_TOKEN = re.compile('|'.join([
- r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier
+ r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier
r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*', # a number
- ] + [ # quoted str for each quote mark
+ ] + [ # quoted str for each quote mark
r'{qt}([^{qt}\n\\]|\\.)*({qt}|\\?$)'.format(qt=mark) for mark in _QUOTES
]))
- _IDENTIFIER = re.compile(r'\w+')
+ _IDENTIFIER = re.compile(r'[^\d\W]\w*')
+ _IDENTIFIER_OR_NUMBER = re.compile(r'\w+')
- def __init__(self, lines):
+ def __init__(self, lines, skip_comments=True):
self._position = 0
self._line = -1
self._column = 0
@@ -789,6 +954,9 @@ class _Tokenizer(object):
self._previous_line = 0
self._previous_column = 0
self._more_lines = True
+ self._skip_comments = skip_comments
+ self._whitespace_pattern = (skip_comments and self._WHITESPACE_OR_COMMENT
+ or self._WHITESPACE)
self._SkipWhitespace()
self.NextToken()
@@ -818,7 +986,7 @@ class _Tokenizer(object):
def _SkipWhitespace(self):
while True:
self._PopLine()
- match = self._WHITESPACE.match(self._current_line, self._column)
+ match = self._whitespace_pattern.match(self._current_line, self._column)
if not match:
break
length = len(match.group(0))
@@ -848,7 +1016,14 @@ class _Tokenizer(object):
ParseError: If the text couldn't be consumed.
"""
if not self.TryConsume(token):
- raise self._ParseError('Expected "%s".' % token)
+ raise self.ParseError('Expected "%s".' % token)
+
+ def ConsumeComment(self):
+ result = self.token
+ if not self._COMMENT.match(result):
+ raise self.ParseError('Expected comment.')
+ self.NextToken()
+ return result
def TryConsumeIdentifier(self):
try:
@@ -868,85 +1043,55 @@ class _Tokenizer(object):
"""
result = self.token
if not self._IDENTIFIER.match(result):
- raise self._ParseError('Expected identifier.')
+ raise self.ParseError('Expected identifier.')
self.NextToken()
return result
- def ConsumeInt32(self):
- """Consumes a signed 32bit integer number.
-
- Returns:
- The integer parsed.
-
- Raises:
- ParseError: If a signed 32bit integer couldn't be consumed.
- """
+ def TryConsumeIdentifierOrNumber(self):
try:
- result = ParseInteger(self.token, is_signed=True, is_long=False)
- except ValueError as e:
- raise self._ParseError(str(e))
- self.NextToken()
- return result
-
- def ConsumeUint32(self):
- """Consumes an unsigned 32bit integer number.
-
- Returns:
- The integer parsed.
-
- Raises:
- ParseError: If an unsigned 32bit integer couldn't be consumed.
- """
- try:
- result = ParseInteger(self.token, is_signed=False, is_long=False)
- except ValueError as e:
- raise self._ParseError(str(e))
- self.NextToken()
- return result
-
- def TryConsumeInt64(self):
- try:
- self.ConsumeInt64()
+ self.ConsumeIdentifierOrNumber()
return True
except ParseError:
return False
- def ConsumeInt64(self):
- """Consumes a signed 64bit integer number.
+ def ConsumeIdentifierOrNumber(self):
+ """Consumes protocol message field identifier.
Returns:
- The integer parsed.
+ Identifier string.
Raises:
- ParseError: If a signed 64bit integer couldn't be consumed.
+ ParseError: If an identifier couldn't be consumed.
"""
- try:
- result = ParseInteger(self.token, is_signed=True, is_long=True)
- except ValueError as e:
- raise self._ParseError(str(e))
+ result = self.token
+ if not self._IDENTIFIER_OR_NUMBER.match(result):
+ raise self.ParseError('Expected identifier or number.')
self.NextToken()
return result
- def TryConsumeUint64(self):
+ def TryConsumeInteger(self):
try:
- self.ConsumeUint64()
+ # Note: is_long only affects value type, not whether an error is raised.
+ self.ConsumeInteger()
return True
except ParseError:
return False
- def ConsumeUint64(self):
- """Consumes an unsigned 64bit integer number.
+ def ConsumeInteger(self, is_long=False):
+ """Consumes an integer number.
+ Args:
+ is_long: True if the value should be returned as a long integer.
Returns:
The integer parsed.
Raises:
- ParseError: If an unsigned 64bit integer couldn't be consumed.
+ ParseError: If an integer couldn't be consumed.
"""
try:
- result = ParseInteger(self.token, is_signed=False, is_long=True)
+ result = _ParseAbstractInteger(self.token, is_long=is_long)
except ValueError as e:
- raise self._ParseError(str(e))
+ raise self.ParseError(str(e))
self.NextToken()
return result
@@ -969,7 +1114,7 @@ class _Tokenizer(object):
try:
result = ParseFloat(self.token)
except ValueError as e:
- raise self._ParseError(str(e))
+ raise self.ParseError(str(e))
self.NextToken()
return result
@@ -985,7 +1130,7 @@ class _Tokenizer(object):
try:
result = ParseBool(self.token)
except ValueError as e:
- raise self._ParseError(str(e))
+ raise self.ParseError(str(e))
self.NextToken()
return result
@@ -1039,15 +1184,15 @@ class _Tokenizer(object):
"""
text = self.token
if len(text) < 1 or text[0] not in _QUOTES:
- raise self._ParseError('Expected string but found: %r' % (text,))
+ raise self.ParseError('Expected string but found: %r' % (text,))
if len(text) < 2 or text[-1] != text[0]:
- raise self._ParseError('String missing ending quote: %r' % (text,))
+ raise self.ParseError('String missing ending quote: %r' % (text,))
try:
result = text_encoding.CUnescape(text[1:-1])
except ValueError as e:
- raise self._ParseError(str(e))
+ raise self.ParseError(str(e))
self.NextToken()
return result
@@ -1055,7 +1200,7 @@ class _Tokenizer(object):
try:
result = ParseEnum(field, self.token)
except ValueError as e:
- raise self._ParseError(str(e))
+ raise self.ParseError(str(e))
self.NextToken()
return result
@@ -1068,16 +1213,15 @@ class _Tokenizer(object):
Returns:
A ParseError instance.
"""
- return ParseError('%d:%d : %s' % (
- self._previous_line + 1, self._previous_column + 1, message))
+ return ParseError(message, self._previous_line + 1,
+ self._previous_column + 1)
- def _ParseError(self, message):
+ def ParseError(self, message):
"""Creates and *returns* a ParseError for the current token."""
- return ParseError('%d:%d : %s' % (
- self._line + 1, self._column + 1, message))
+ return ParseError(message, self._line + 1, self._column + 1)
def _StringParseError(self, e):
- return self._ParseError('Couldn\'t parse string: ' + str(e))
+ return self.ParseError('Couldn\'t parse string: ' + str(e))
def NextToken(self):
"""Reads the next meaningful token."""
@@ -1092,12 +1236,124 @@ class _Tokenizer(object):
return
match = self._TOKEN.match(self._current_line, self._column)
+ if not match and not self._skip_comments:
+ match = self._COMMENT.match(self._current_line, self._column)
if match:
token = match.group(0)
self.token = token
else:
self.token = self._current_line[self._column]
+# Aliased so it can still be accessed by current visibility violators.
+# TODO(dbarnett): Migrate violators to textformat_tokenizer.
+_Tokenizer = Tokenizer # pylint: disable=invalid-name
+
+
+def _ConsumeInt32(tokenizer):
+ """Consumes a signed 32bit integer number from tokenizer.
+
+ Args:
+ tokenizer: A tokenizer used to parse the number.
+
+ Returns:
+ The integer parsed.
+
+ Raises:
+ ParseError: If a signed 32bit integer couldn't be consumed.
+ """
+ return _ConsumeInteger(tokenizer, is_signed=True, is_long=False)
+
+
+def _ConsumeUint32(tokenizer):
+ """Consumes an unsigned 32bit integer number from tokenizer.
+
+ Args:
+ tokenizer: A tokenizer used to parse the number.
+
+ Returns:
+ The integer parsed.
+
+ Raises:
+ ParseError: If an unsigned 32bit integer couldn't be consumed.
+ """
+ return _ConsumeInteger(tokenizer, is_signed=False, is_long=False)
+
+
+def _TryConsumeInt64(tokenizer):
+ try:
+ _ConsumeInt64(tokenizer)
+ return True
+ except ParseError:
+ return False
+
+
+def _ConsumeInt64(tokenizer):
+ """Consumes a signed 32bit integer number from tokenizer.
+
+ Args:
+ tokenizer: A tokenizer used to parse the number.
+
+ Returns:
+ The integer parsed.
+
+ Raises:
+ ParseError: If a signed 32bit integer couldn't be consumed.
+ """
+ return _ConsumeInteger(tokenizer, is_signed=True, is_long=True)
+
+
+def _TryConsumeUint64(tokenizer):
+ try:
+ _ConsumeUint64(tokenizer)
+ return True
+ except ParseError:
+ return False
+
+
+def _ConsumeUint64(tokenizer):
+ """Consumes an unsigned 64bit integer number from tokenizer.
+
+ Args:
+ tokenizer: A tokenizer used to parse the number.
+
+ Returns:
+ The integer parsed.
+
+ Raises:
+ ParseError: If an unsigned 64bit integer couldn't be consumed.
+ """
+ return _ConsumeInteger(tokenizer, is_signed=False, is_long=True)
+
+
+def _TryConsumeInteger(tokenizer, is_signed=False, is_long=False):
+ try:
+ _ConsumeInteger(tokenizer, is_signed=is_signed, is_long=is_long)
+ return True
+ except ParseError:
+ return False
+
+
+def _ConsumeInteger(tokenizer, is_signed=False, is_long=False):
+ """Consumes an integer number from tokenizer.
+
+ Args:
+ tokenizer: A tokenizer used to parse the number.
+ is_signed: True if a signed integer must be parsed.
+ is_long: True if a long integer must be parsed.
+
+ Returns:
+ The integer parsed.
+
+ Raises:
+ ParseError: If an integer with given characteristics couldn't be consumed.
+ """
+ try:
+ result = ParseInteger(tokenizer.token, is_signed=is_signed, is_long=is_long)
+ except ValueError as e:
+ raise tokenizer.ParseError(str(e))
+ tokenizer.NextToken()
+ return result
+
def ParseInteger(text, is_signed=False, is_long=False):
"""Parses an integer.
@@ -1114,22 +1370,39 @@ def ParseInteger(text, is_signed=False, is_long=False):
ValueError: Thrown Iff the text is not a valid integer.
"""
# Do the actual parsing. Exception handling is propagated to caller.
+ result = _ParseAbstractInteger(text, is_long=is_long)
+
+ # Check if the integer is sane. Exceptions handled by callers.
+ checker = _INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)]
+ checker.CheckValue(result)
+ return result
+
+
+def _ParseAbstractInteger(text, is_long=False):
+ """Parses an integer without checking size/signedness.
+
+ Args:
+ text: The text to parse.
+ is_long: True if the value should be returned as a long integer.
+
+ Returns:
+ The integer value.
+
+ Raises:
+ ValueError: Thrown Iff the text is not a valid integer.
+ """
+ # Do the actual parsing. Exception handling is propagated to caller.
try:
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler.
if is_long:
- result = long(text, 0)
+ return long(text, 0)
else:
- result = int(text, 0)
+ return int(text, 0)
except ValueError:
raise ValueError('Couldn\'t parse integer: %s' % text)
- # Check if the integer is sane. Exceptions handled by callers.
- checker = _INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)]
- checker.CheckValue(result)
- return result
-
def ParseFloat(text):
"""Parse a floating point number.
@@ -1206,14 +1479,12 @@ def ParseEnum(field, value):
# Identifier.
enum_value = enum_descriptor.values_by_name.get(value, None)
if enum_value is None:
- raise ValueError(
- 'Enum type "%s" has no value named %s.' % (
- enum_descriptor.full_name, value))
+ raise ValueError('Enum type "%s" has no value named %s.' %
+ (enum_descriptor.full_name, value))
else:
# Numeric value.
enum_value = enum_descriptor.values_by_number.get(number, None)
if enum_value is None:
- raise ValueError(
- 'Enum type "%s" has no value with number %d.' % (
- enum_descriptor.full_name, number))
+ raise ValueError('Enum type "%s" has no value with number %d.' %
+ (enum_descriptor.full_name, number))
return enum_value.number