From 0400cca3236de1ca303af38bf81eab332d042b7c Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Tue, 13 Mar 2018 16:37:29 -0700 Subject: Integrated internal changes from Google --- python/google/protobuf/text_format.py | 88 ++++++++++++++++++++--------------- 1 file changed, 51 insertions(+), 37 deletions(-) (limited to 'python/google/protobuf/text_format.py') diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index 36ddd1b7..2cbd21bc 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -141,9 +141,11 @@ def MessageToString(message, as_one_line: Don't introduce newlines between fields. pointy_brackets: If True, use angle brackets instead of curly braces for nesting. - use_index_order: If True, print fields of a proto message using the order - defined in source code instead of the field number. By default, use the - field number order. + use_index_order: If True, fields of a proto message will be printed using + the order defined in source code instead of the field number, extensions + will be printed at the end of the message and their relative order is + determined by the extension number. By default, use the field number + order. float_format: If set, use this to specify floating point number formatting (per the "Format Specification Mini-Language"); otherwise, str() is used. use_field_number: If True, print field numbers instead of names. @@ -336,11 +338,12 @@ class _Printer(object): return fields = message.ListFields() if self.use_index_order: - fields.sort(key=lambda x: x[0].index) + fields.sort( + key=lambda x: x[0].number if x[0].is_extension else x[0].index) for field, value in fields: if _IsMapEntry(field): for key in sorted(value): - # This is slow for maps with submessage entires because it copies the + # This is slow for maps with submessage entries because it copies the # entire tree. Unfortunately this would take significant refactoring # of this file to work around. # @@ -645,6 +648,30 @@ class _Parser(object): ParseError: In case of text parsing problems. """ message_descriptor = message.DESCRIPTOR + if (message_descriptor.full_name == _ANY_FULL_TYPE_NAME and + tokenizer.TryConsume('[')): + type_url_prefix, packed_type_name = self._ConsumeAnyTypeUrl(tokenizer) + tokenizer.Consume(']') + tokenizer.TryConsume(':') + if tokenizer.TryConsume('<'): + expanded_any_end_token = '>' + else: + tokenizer.Consume('{') + expanded_any_end_token = '}' + expanded_any_sub_message = _BuildMessageFromTypeName(packed_type_name, + self.descriptor_pool) + if not expanded_any_sub_message: + raise ParseError('Type %s not found in descriptor pool' % + packed_type_name) + while not tokenizer.TryConsume(expanded_any_end_token): + if tokenizer.AtEnd(): + raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % + (expanded_any_end_token,)) + self._MergeField(tokenizer, expanded_any_sub_message) + message.Pack(expanded_any_sub_message, + type_url_prefix=type_url_prefix) + return + if tokenizer.TryConsume('['): name = [tokenizer.ConsumeIdentifier()] while tokenizer.TryConsume('.'): @@ -725,11 +752,12 @@ class _Parser(object): if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED and tokenizer.TryConsume('[')): # Short repeated format, e.g. "foo: [1, 2, 3]" - while True: - merger(tokenizer, message, field) - if tokenizer.TryConsume(']'): - break - tokenizer.Consume(',') + if not tokenizer.TryConsume(']'): + while True: + merger(tokenizer, message, field) + if tokenizer.TryConsume(']'): + break + tokenizer.Consume(',') else: merger(tokenizer, message, field) @@ -777,33 +805,7 @@ class _Parser(object): tokenizer.Consume('{') end_token = '}' - if (field.message_type.full_name == _ANY_FULL_TYPE_NAME and - tokenizer.TryConsume('[')): - type_url_prefix, packed_type_name = self._ConsumeAnyTypeUrl(tokenizer) - tokenizer.Consume(']') - tokenizer.TryConsume(':') - if tokenizer.TryConsume('<'): - expanded_any_end_token = '>' - else: - tokenizer.Consume('{') - expanded_any_end_token = '}' - expanded_any_sub_message = _BuildMessageFromTypeName(packed_type_name, - self.descriptor_pool) - if not expanded_any_sub_message: - raise ParseError('Type %s not found in descriptor pool' % - packed_type_name) - while not tokenizer.TryConsume(expanded_any_end_token): - if tokenizer.AtEnd(): - raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % - (expanded_any_end_token,)) - self._MergeField(tokenizer, expanded_any_sub_message) - if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - any_message = getattr(message, field.name).add() - else: - any_message = getattr(message, field.name) - any_message.Pack(expanded_any_sub_message, - type_url_prefix=type_url_prefix) - elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: if field.is_extension: sub_message = message.Extensions[field].add() elif is_map_entry: @@ -812,8 +814,20 @@ class _Parser(object): sub_message = getattr(message, field.name).add() else: if field.is_extension: + if (not self._allow_multiple_scalars and + message.HasExtension(field)): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" extensions.' % + (message.DESCRIPTOR.full_name, field.full_name)) sub_message = message.Extensions[field] else: + # Also apply _allow_multiple_scalars to message field. + # TODO(jieluo): Change to _allow_singular_overwrites. + if (not self._allow_multiple_scalars and + message.HasField(field.name)): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" fields.' % + (message.DESCRIPTOR.full_name, field.name)) sub_message = getattr(message, field.name) sub_message.SetInParent() -- cgit v1.2.3