aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/text_format.py
diff options
context:
space:
mode:
authorGravatar Jisi Liu <jisi.liu@gmail.com>2017-07-18 15:38:30 -0700
committerGravatar Jisi Liu <jisi.liu@gmail.com>2017-07-18 15:38:30 -0700
commit09354db1434859a31a3c81abebcc4018d42f2715 (patch)
treeb87c7cdc2255e6c8062ab92b4082665cd698d753 /python/google/protobuf/text_format.py
parent9053033a5076f82cf18b823c31f352e95e5bfd8d (diff)
Merge from Google internal for 3.4 release
Diffstat (limited to 'python/google/protobuf/text_format.py')
-rwxr-xr-xpython/google/protobuf/text_format.py107
1 files changed, 80 insertions, 27 deletions
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
index c216e097..aaca78ad 100755
--- a/python/google/protobuf/text_format.py
+++ b/python/google/protobuf/text_format.py
@@ -126,7 +126,8 @@ def MessageToString(message,
float_format=None,
use_field_number=False,
descriptor_pool=None,
- indent=0):
+ indent=0,
+ message_formatter=None):
"""Convert protobuf message to text format.
Floating point values can be formatted compactly with 15 digits of
@@ -148,6 +149,9 @@ def MessageToString(message,
use_field_number: If True, print field numbers instead of names.
descriptor_pool: A DescriptorPool used to resolve Any types.
indent: The indent level, in terms of spaces, for pretty print.
+ message_formatter: A function(message, indent, as_one_line): unicode|None
+ to custom format selected sub-messages (usually based on message type).
+ Use to pretty print parts of the protobuf for easier diffing.
Returns:
A string of the text formatted protocol buffer message.
@@ -155,7 +159,7 @@ def MessageToString(message,
out = TextWriter(as_utf8)
printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
use_index_order, float_format, use_field_number,
- descriptor_pool)
+ descriptor_pool, message_formatter)
printer.PrintMessage(message)
result = out.getvalue()
out.close()
@@ -179,10 +183,11 @@ def PrintMessage(message,
use_index_order=False,
float_format=None,
use_field_number=False,
- descriptor_pool=None):
+ descriptor_pool=None,
+ message_formatter=None):
printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
use_index_order, float_format, use_field_number,
- descriptor_pool)
+ descriptor_pool, message_formatter)
printer.PrintMessage(message)
@@ -194,10 +199,11 @@ def PrintField(field,
as_one_line=False,
pointy_brackets=False,
use_index_order=False,
- float_format=None):
+ float_format=None,
+ message_formatter=None):
"""Print a single field name/value pair."""
printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
- use_index_order, float_format)
+ use_index_order, float_format, message_formatter)
printer.PrintField(field, value)
@@ -209,10 +215,11 @@ def PrintFieldValue(field,
as_one_line=False,
pointy_brackets=False,
use_index_order=False,
- float_format=None):
+ float_format=None,
+ message_formatter=None):
"""Print a single field value (not including name)."""
printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
- use_index_order, float_format)
+ use_index_order, float_format, message_formatter)
printer.PrintFieldValue(field, value)
@@ -228,6 +235,9 @@ def _BuildMessageFromTypeName(type_name, descriptor_pool):
wasn't found matching type_name.
"""
# pylint: disable=g-import-not-at-top
+ if descriptor_pool is None:
+ from google.protobuf import descriptor_pool as pool_mod
+ descriptor_pool = pool_mod.Default()
from google.protobuf import symbol_database
database = symbol_database.Default()
try:
@@ -250,7 +260,8 @@ class _Printer(object):
use_index_order=False,
float_format=None,
use_field_number=False,
- descriptor_pool=None):
+ descriptor_pool=None,
+ message_formatter=None):
"""Initialize the Printer.
Floating point values can be formatted compactly with 15 digits of
@@ -273,6 +284,9 @@ class _Printer(object):
used.
use_field_number: If True, print field numbers instead of names.
descriptor_pool: A DescriptorPool used to resolve Any types.
+ message_formatter: A function(message, indent, as_one_line): unicode|None
+ to custom format selected sub-messages (usually based on message type).
+ Use to pretty print parts of the protobuf for easier diffing.
"""
self.out = out
self.indent = indent
@@ -283,6 +297,7 @@ class _Printer(object):
self.float_format = float_format
self.use_field_number = use_field_number
self.descriptor_pool = descriptor_pool
+ self.message_formatter = message_formatter
def _TryPrintAsAnyMessage(self, message):
"""Serializes if message is a google.protobuf.Any field."""
@@ -297,14 +312,27 @@ class _Printer(object):
else:
return False
+ def _TryCustomFormatMessage(self, message):
+ formatted = self.message_formatter(message, self.indent, self.as_one_line)
+ if formatted is None:
+ return False
+
+ out = self.out
+ out.write(' ' * self.indent)
+ out.write(formatted)
+ out.write(' ' if self.as_one_line else '\n')
+ return True
+
def PrintMessage(self, message):
"""Convert protobuf message to text format.
Args:
message: The protocol buffers message.
"""
+ if self.message_formatter and self._TryCustomFormatMessage(message):
+ return
if (message.DESCRIPTOR.full_name == _ANY_FULL_TYPE_NAME and
- self.descriptor_pool and self._TryPrintAsAnyMessage(message)):
+ self._TryPrintAsAnyMessage(message)):
return
fields = message.ListFields()
if self.use_index_order:
@@ -426,6 +454,22 @@ def Parse(text,
descriptor_pool=None):
"""Parses a text representation of a protocol message into a message.
+ NOTE: for historical reasons this function does not clear the input
+ message. This is different from what the binary msg.ParseFrom(...) does.
+
+ Example
+ a = MyProto()
+ a.repeated_field.append('test')
+ b = MyProto()
+
+ text_format.Parse(repr(a), b)
+ text_format.Parse(repr(a), b) # repeated_field contains ["test", "test"]
+
+ # Binary version:
+ b.ParseFromString(a.SerializeToString()) # repeated_field is now "test"
+
+ Caller is responsible for clearing the message as needed.
+
Args:
text: Message text representation.
message: A protocol buffer message to merge into.
@@ -593,11 +637,6 @@ class _Parser(object):
ParseError: In case of text parsing problems.
"""
message_descriptor = message.DESCRIPTOR
- if (hasattr(message_descriptor, 'syntax') and
- message_descriptor.syntax == 'proto3'):
- # Proto3 doesn't represent presence so we can't test if multiple
- # scalars have occurred. We have to allow them.
- self._allow_multiple_scalars = True
if tokenizer.TryConsume('['):
name = [tokenizer.ConsumeIdentifier()]
while tokenizer.TryConsume('.'):
@@ -616,7 +655,11 @@ class _Parser(object):
field = None
else:
raise tokenizer.ParseErrorPreviousToken(
- 'Extension "%s" not registered.' % name)
+ 'Extension "%s" not registered. '
+ 'Did you import the _pb2 module which defines it? '
+ 'If you are trying to place the extension in the MessageSet '
+ 'field of another message that is in an Any or MessageSet field, '
+ 'that message\'s _pb2 module must be imported as well' % name)
elif message_descriptor != field.containing_type:
raise tokenizer.ParseErrorPreviousToken(
'Extension "%s" does not extend message type "%s".' %
@@ -695,17 +738,17 @@ class _Parser(object):
def _ConsumeAnyTypeUrl(self, tokenizer):
"""Consumes a google.protobuf.Any type URL and returns the type name."""
# Consume "type.googleapis.com/".
- tokenizer.ConsumeIdentifier()
+ prefix = [tokenizer.ConsumeIdentifier()]
tokenizer.Consume('.')
- tokenizer.ConsumeIdentifier()
+ prefix.append(tokenizer.ConsumeIdentifier())
tokenizer.Consume('.')
- tokenizer.ConsumeIdentifier()
+ prefix.append(tokenizer.ConsumeIdentifier())
tokenizer.Consume('/')
# Consume the fully-qualified type name.
name = [tokenizer.ConsumeIdentifier()]
while tokenizer.TryConsume('.'):
name.append(tokenizer.ConsumeIdentifier())
- return '.'.join(name)
+ return '.'.join(prefix), '.'.join(name)
def _MergeMessageField(self, tokenizer, message, field):
"""Merges a single scalar field into a message.
@@ -728,7 +771,7 @@ class _Parser(object):
if (field.message_type.full_name == _ANY_FULL_TYPE_NAME and
tokenizer.TryConsume('[')):
- packed_type_name = self._ConsumeAnyTypeUrl(tokenizer)
+ type_url_prefix, packed_type_name = self._ConsumeAnyTypeUrl(tokenizer)
tokenizer.Consume(']')
tokenizer.TryConsume(':')
if tokenizer.TryConsume('<'):
@@ -736,8 +779,6 @@ class _Parser(object):
else:
tokenizer.Consume('{')
expanded_any_end_token = '}'
- if not self.descriptor_pool:
- raise ParseError('Descriptor pool required to parse expanded Any field')
expanded_any_sub_message = _BuildMessageFromTypeName(packed_type_name,
self.descriptor_pool)
if not expanded_any_sub_message:
@@ -752,7 +793,8 @@ class _Parser(object):
any_message = getattr(message, field.name).add()
else:
any_message = getattr(message, field.name)
- any_message.Pack(expanded_any_sub_message)
+ any_message.Pack(expanded_any_sub_message,
+ type_url_prefix=type_url_prefix)
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
if field.is_extension:
sub_message = message.Extensions[field].add()
@@ -780,6 +822,12 @@ class _Parser(object):
else:
getattr(message, field.name)[sub_message.key] = sub_message.value
+ @staticmethod
+ def _IsProto3Syntax(message):
+ message_descriptor = message.DESCRIPTOR
+ return (hasattr(message_descriptor, 'syntax') and
+ message_descriptor.syntax == 'proto3')
+
def _MergeScalarField(self, tokenizer, message, field):
"""Merges a single scalar field into a message.
@@ -829,15 +877,20 @@ class _Parser(object):
else:
getattr(message, field.name).append(value)
else:
+ # Proto3 doesn't represent presence so we can't test if multiple scalars
+ # have occurred. We have to allow them.
+ can_check_presence = not self._IsProto3Syntax(message)
if field.is_extension:
- if not self._allow_multiple_scalars and message.HasExtension(field):
+ if (not self._allow_multiple_scalars and can_check_presence and
+ message.HasExtension(field)):
raise tokenizer.ParseErrorPreviousToken(
'Message type "%s" should not have multiple "%s" extensions.' %
(message.DESCRIPTOR.full_name, field.full_name))
else:
message.Extensions[field] = value
else:
- if not self._allow_multiple_scalars and message.HasField(field.name):
+ if (not self._allow_multiple_scalars and can_check_presence and
+ message.HasField(field.name)):
raise tokenizer.ParseErrorPreviousToken(
'Message type "%s" should not have multiple "%s" fields.' %
(message.DESCRIPTOR.full_name, field.name))
@@ -1088,7 +1141,7 @@ class Tokenizer(object):
"""
result = self.token
if not self._IDENTIFIER_OR_NUMBER.match(result):
- raise self.ParseError('Expected identifier or number.')
+ raise self.ParseError('Expected identifier or number, got %s.' % result)
self.NextToken()
return result