From cf14183bcd5485b4a71541599ddce0b35eb71352 Mon Sep 17 00:00:00 2001 From: Jisi Liu Date: Thu, 28 Apr 2016 14:34:59 -0700 Subject: Down integrate from Google internal. --- python/google/protobuf/text_format.py | 816 ++++++++++++++++++---------------- 1 file changed, 440 insertions(+), 376 deletions(-) (limited to 'python/google/protobuf/text_format.py') diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index a6f41ca8..6f1e3c8b 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -99,7 +99,7 @@ class TextWriter(object): def MessageToString(message, as_utf8=False, as_one_line=False, pointy_brackets=False, use_index_order=False, - float_format=None): + float_format=None, use_field_number=False): """Convert protobuf message to text format. Floating point values can be formatted compactly with 15 digits of @@ -118,15 +118,16 @@ def MessageToString(message, as_utf8=False, as_one_line=False, field number order. float_format: If set, use this to specify floating point number formatting (per the "Format Specification Mini-Language"); otherwise, str() is used. + use_field_number: If True, print field numbers instead of names. Returns: A string of the text formatted protocol buffer message. """ out = TextWriter(as_utf8) - PrintMessage(message, out, as_utf8=as_utf8, as_one_line=as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, - float_format=float_format) + printer = _Printer(out, 0, as_utf8, as_one_line, + pointy_brackets, use_index_order, float_format, + use_field_number) + printer.PrintMessage(message) result = out.getvalue() out.close() if as_one_line: @@ -142,133 +143,187 @@ def _IsMapEntry(field): def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False, pointy_brackets=False, use_index_order=False, - float_format=None): - fields = message.ListFields() - if use_index_order: - fields.sort(key=lambda x: x[0].index) - for field, value in fields: - if _IsMapEntry(field): - for key in sorted(value): - # This is slow for maps with submessage entires because it copies the - # entire tree. Unfortunately this would take significant refactoring - # of this file to work around. - # - # TODO(haberman): refactor and optimize if this becomes an issue. - entry_submsg = field.message_type._concrete_class( - key=key, value=value[key]) - PrintField(field, entry_submsg, out, indent, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, float_format=float_format) - elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - for element in value: - PrintField(field, element, out, indent, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, - float_format=float_format) - else: - PrintField(field, value, out, indent, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, - float_format=float_format) + float_format=None, use_field_number=False): + printer = _Printer(out, indent, as_utf8, as_one_line, + pointy_brackets, use_index_order, float_format, + use_field_number) + printer.PrintMessage(message) def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False, pointy_brackets=False, use_index_order=False, float_format=None): - """Print a single field name/value pair. For repeated fields, the value - should be a single element. - """ - - out.write(' ' * indent) - if field.is_extension: - out.write('[') - if (field.containing_type.GetOptions().message_set_wire_format and - field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and - field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): - out.write(field.message_type.full_name) - else: - out.write(field.full_name) - out.write(']') - elif field.type == descriptor.FieldDescriptor.TYPE_GROUP: - # For groups, use the capitalized name. - out.write(field.message_type.name) - else: - out.write(field.name) - - if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - # The colon is optional in this case, but our cross-language golden files - # don't include it. - out.write(': ') - - PrintFieldValue(field, value, out, indent, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, - float_format=float_format) - if as_one_line: - out.write(' ') - else: - out.write('\n') + """Print a single field name/value pair.""" + printer = _Printer(out, indent, as_utf8, as_one_line, + pointy_brackets, use_index_order, float_format) + printer.PrintField(field, value) def PrintFieldValue(field, value, out, indent=0, as_utf8=False, as_one_line=False, pointy_brackets=False, use_index_order=False, float_format=None): - """Print a single field value (not including name). For repeated fields, - the value should be a single element.""" + """Print a single field value (not including name).""" + printer = _Printer(out, indent, as_utf8, as_one_line, + pointy_brackets, use_index_order, float_format) + printer.PrintFieldValue(field, value) - if pointy_brackets: - openb = '<' - closeb = '>' - else: - openb = '{' - closeb = '}' - - if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - if as_one_line: - out.write(' %s ' % openb) - PrintMessage(value, out, indent, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, - float_format=float_format) - out.write(closeb) - else: - out.write(' %s\n' % openb) - PrintMessage(value, out, indent + 2, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, - float_format=float_format) - out.write(' ' * indent + closeb) - elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: - enum_value = field.enum_type.values_by_number.get(value, None) - if enum_value is not None: - out.write(enum_value.name) + +class _Printer(object): + """Text format printer for protocol message.""" + + def __init__(self, out, indent=0, as_utf8=False, as_one_line=False, + pointy_brackets=False, use_index_order=False, float_format=None, + use_field_number=False): + """Initialize the Printer. + + Floating point values can be formatted compactly with 15 digits of + precision (which is the most that IEEE 754 "double" can guarantee) + using float_format='.15g'. To ensure that converting to text and back to a + proto will result in an identical value, float_format='.17g' should be used. + + Args: + out: To record the text format result. + indent: The indent level for pretty print. + as_utf8: Produce text output in UTF8 format. + as_one_line: Don't introduce newlines between fields. + pointy_brackets: If True, use angle brackets instead of curly braces for + nesting. + use_index_order: If True, print fields of a proto message using the order + defined in source code instead of the field number. By default, use the + field number order. + float_format: If set, use this to specify floating point number formatting + (per the "Format Specification Mini-Language"); otherwise, str() is + used. + use_field_number: If True, print field numbers instead of names. + """ + self.out = out + self.indent = indent + self.as_utf8 = as_utf8 + self.as_one_line = as_one_line + self.pointy_brackets = pointy_brackets + self.use_index_order = use_index_order + self.float_format = float_format + self.use_field_number = use_field_number + + def PrintMessage(self, message): + """Convert protobuf message to text format. + + Args: + message: The protocol buffers message. + """ + fields = message.ListFields() + if self.use_index_order: + fields.sort(key=lambda x: x[0].index) + for field, value in fields: + if _IsMapEntry(field): + for key in sorted(value): + # This is slow for maps with submessage entires because it copies the + # entire tree. Unfortunately this would take significant refactoring + # of this file to work around. + # + # TODO(haberman): refactor and optimize if this becomes an issue. + entry_submsg = field.message_type._concrete_class( + key=key, value=value[key]) + self.PrintField(field, entry_submsg) + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + for element in value: + self.PrintField(field, element) + else: + self.PrintField(field, value) + + def PrintField(self, field, value): + """Print a single field name/value pair.""" + out = self.out + out.write(' ' * self.indent) + if self.use_field_number: + out.write(str(field.number)) else: - out.write(str(value)) - elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: - out.write('\"') - if isinstance(value, six.text_type): - out_value = value.encode('utf-8') + if field.is_extension: + out.write('[') + if (field.containing_type.GetOptions().message_set_wire_format and + field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): + out.write(field.message_type.full_name) + else: + out.write(field.full_name) + out.write(']') + elif field.type == descriptor.FieldDescriptor.TYPE_GROUP: + # For groups, use the capitalized name. + out.write(field.message_type.name) + else: + out.write(field.name) + + if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + # The colon is optional in this case, but our cross-language golden files + # don't include it. + out.write(': ') + + self.PrintFieldValue(field, value) + if self.as_one_line: + out.write(' ') else: - out_value = value - if field.type == descriptor.FieldDescriptor.TYPE_BYTES: - # We need to escape non-UTF8 chars in TYPE_BYTES field. - out_as_utf8 = False + out.write('\n') + + def PrintFieldValue(self, field, value): + """Print a single field value (not including name). + + For repeated fields, the value should be a single element. + + Args: + field: The descriptor of the field to be printed. + value: The value of the field. + """ + out = self.out + if self.pointy_brackets: + openb = '<' + closeb = '>' else: - out_as_utf8 = as_utf8 - out.write(text_encoding.CEscape(out_value, out_as_utf8)) - out.write('\"') - elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: - if value: - out.write('true') + openb = '{' + closeb = '}' + + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + if self.as_one_line: + out.write(' %s ' % openb) + self.PrintMessage(value) + out.write(closeb) + else: + out.write(' %s\n' % openb) + self.indent += 2 + self.PrintMessage(value) + self.indent -= 2 + out.write(' ' * self.indent + closeb) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + enum_value = field.enum_type.values_by_number.get(value, None) + if enum_value is not None: + out.write(enum_value.name) + else: + out.write(str(value)) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + out.write('\"') + if isinstance(value, six.text_type): + out_value = value.encode('utf-8') + else: + out_value = value + if field.type == descriptor.FieldDescriptor.TYPE_BYTES: + # We need to escape non-UTF8 chars in TYPE_BYTES field. + out_as_utf8 = False + else: + out_as_utf8 = self.as_utf8 + out.write(text_encoding.CEscape(out_value, out_as_utf8)) + out.write('\"') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + if value: + out.write('true') + else: + out.write('false') + elif field.cpp_type in _FLOAT_TYPES and self.float_format is not None: + out.write('{1:{0}}'.format(self.float_format, value)) else: - out.write('false') - elif field.cpp_type in _FLOAT_TYPES and float_format is not None: - out.write('{1:{0}}'.format(float_format, value)) - else: - out.write(str(value)) + out.write(str(value)) -def Parse(text, message, allow_unknown_extension=False): +def Parse(text, message, + allow_unknown_extension=False, allow_field_number=False): """Parses an text representation of a protocol message into a message. Args: @@ -276,6 +331,7 @@ def Parse(text, message, allow_unknown_extension=False): message: A protocol buffer message to merge into. allow_unknown_extension: if True, skip over missing extensions and keep parsing + allow_field_number: if True, both field number and field name are allowed. Returns: The same message passed as argument. @@ -285,10 +341,12 @@ def Parse(text, message, allow_unknown_extension=False): """ if not isinstance(text, str): text = text.decode('utf-8') - return ParseLines(text.split('\n'), message, allow_unknown_extension) + return ParseLines(text.split('\n'), message, allow_unknown_extension, + allow_field_number) -def Merge(text, message, allow_unknown_extension=False): +def Merge(text, message, allow_unknown_extension=False, + allow_field_number=False): """Parses an text representation of a protocol message into a message. Like Parse(), but allows repeated values for a non-repeated field, and uses @@ -299,6 +357,7 @@ def Merge(text, message, allow_unknown_extension=False): message: A protocol buffer message to merge into. allow_unknown_extension: if True, skip over missing extensions and keep parsing + allow_field_number: if True, both field number and field name are allowed. Returns: The same message passed as argument. @@ -306,10 +365,12 @@ def Merge(text, message, allow_unknown_extension=False): Raises: ParseError: On text parsing problems. """ - return MergeLines(text.split('\n'), message, allow_unknown_extension) + return MergeLines(text.split('\n'), message, allow_unknown_extension, + allow_field_number) -def ParseLines(lines, message, allow_unknown_extension=False): +def ParseLines(lines, message, allow_unknown_extension=False, + allow_field_number=False): """Parses an text representation of a protocol message into a message. Args: @@ -317,6 +378,7 @@ def ParseLines(lines, message, allow_unknown_extension=False): message: A protocol buffer message to merge into. allow_unknown_extension: if True, skip over missing extensions and keep parsing + allow_field_number: if True, both field number and field name are allowed. Returns: The same message passed as argument. @@ -324,11 +386,12 @@ def ParseLines(lines, message, allow_unknown_extension=False): Raises: ParseError: On text parsing problems. """ - _ParseOrMerge(lines, message, False, allow_unknown_extension) - return message + parser = _Parser(allow_unknown_extension, allow_field_number) + return parser.ParseLines(lines, message) -def MergeLines(lines, message, allow_unknown_extension=False): +def MergeLines(lines, message, allow_unknown_extension=False, + allow_field_number=False): """Parses an text representation of a protocol message into a message. Args: @@ -336,6 +399,7 @@ def MergeLines(lines, message, allow_unknown_extension=False): message: A protocol buffer message to merge into. allow_unknown_extension: if True, skip over missing extensions and keep parsing + allow_field_number: if True, both field number and field name are allowed. Returns: The same message passed as argument. @@ -343,146 +407,272 @@ def MergeLines(lines, message, allow_unknown_extension=False): Raises: ParseError: On text parsing problems. """ - _ParseOrMerge(lines, message, True, allow_unknown_extension) - return message + parser = _Parser(allow_unknown_extension, allow_field_number) + return parser.MergeLines(lines, message) -def _ParseOrMerge(lines, - message, - allow_multiple_scalars, - allow_unknown_extension=False): - """Converts an text representation of a protocol message into a message. +class _Parser(object): + """Text format parser for protocol message.""" - Args: - lines: Lines of a message's text representation. - message: A protocol buffer message to merge into. - allow_multiple_scalars: Determines if repeated values for a non-repeated - field are permitted, e.g., the string "foo: 1 foo: 2" for a - required/optional field named "foo". - allow_unknown_extension: if True, skip over missing extensions and keep - parsing + def __init__(self, allow_unknown_extension=False, allow_field_number=False): + self.allow_unknown_extension = allow_unknown_extension + self.allow_field_number = allow_field_number - Raises: - ParseError: On text parsing problems. - """ - tokenizer = _Tokenizer(lines) - while not tokenizer.AtEnd(): - _MergeField(tokenizer, message, allow_multiple_scalars, - allow_unknown_extension) + def ParseFromString(self, text, message): + """Parses an text representation of a protocol message into a message.""" + if not isinstance(text, str): + text = text.decode('utf-8') + return self.ParseLines(text.split('\n'), message) + def ParseLines(self, lines, message): + """Parses an text representation of a protocol message into a message.""" + self._allow_multiple_scalars = False + self._ParseOrMerge(lines, message) + return message -def _MergeField(tokenizer, - message, - allow_multiple_scalars, - allow_unknown_extension=False): - """Merges a single protocol message field into a message. + def MergeFromString(self, text, message): + """Merges an text representation of a protocol message into a message.""" + return self._MergeLines(text.split('\n'), message) - Args: - tokenizer: A tokenizer to parse the field name and values. - message: A protocol message to record the data. - allow_multiple_scalars: Determines if repeated values for a non-repeated - field are permitted, e.g., the string "foo: 1 foo: 2" for a - required/optional field named "foo". - allow_unknown_extension: if True, skip over missing extensions and keep - parsing. + def MergeLines(self, lines, message): + """Merges an text representation of a protocol message into a message.""" + self._allow_multiple_scalars = True + self._ParseOrMerge(lines, message) + return message - Raises: - ParseError: In case of text parsing problems. - """ - message_descriptor = message.DESCRIPTOR - if (hasattr(message_descriptor, 'syntax') and - message_descriptor.syntax == 'proto3'): - # Proto3 doesn't represent presence so we can't test if multiple - # scalars have occurred. We have to allow them. - allow_multiple_scalars = True - if tokenizer.TryConsume('['): - name = [tokenizer.ConsumeIdentifier()] - while tokenizer.TryConsume('.'): - name.append(tokenizer.ConsumeIdentifier()) - name = '.'.join(name) - - if not message_descriptor.is_extendable: - raise tokenizer.ParseErrorPreviousToken( - 'Message type "%s" does not have extensions.' % - message_descriptor.full_name) - # pylint: disable=protected-access - field = message.Extensions._FindExtensionByName(name) - # pylint: enable=protected-access - if not field: - if allow_unknown_extension: - field = None - else: - raise tokenizer.ParseErrorPreviousToken( - 'Extension "%s" not registered.' % name) - elif message_descriptor != field.containing_type: - raise tokenizer.ParseErrorPreviousToken( - 'Extension "%s" does not extend message type "%s".' % ( - name, message_descriptor.full_name)) + def _ParseOrMerge(self, lines, message): + """Converts an text representation of a protocol message into a message. - tokenizer.Consume(']') + Args: + lines: Lines of a message's text representation. + message: A protocol buffer message to merge into. - else: - name = tokenizer.ConsumeIdentifier() - field = message_descriptor.fields_by_name.get(name, None) - - # Group names are expected to be capitalized as they appear in the - # .proto file, which actually matches their type names, not their field - # names. - if not field: - field = message_descriptor.fields_by_name.get(name.lower(), None) - if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP: - field = None - - if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and - field.message_type.name != name): - field = None - - if not field: - raise tokenizer.ParseErrorPreviousToken( - 'Message type "%s" has no field named "%s".' % ( - message_descriptor.full_name, name)) - - if field: - if not allow_multiple_scalars and field.containing_oneof: - # Check if there's a different field set in this oneof. - # Note that we ignore the case if the same field was set before, and we - # apply allow_multiple_scalars to non-scalar fields as well. - which_oneof = message.WhichOneof(field.containing_oneof.name) - if which_oneof is not None and which_oneof != field.name: + Raises: + ParseError: On text parsing problems. + """ + tokenizer = _Tokenizer(lines) + while not tokenizer.AtEnd(): + self._MergeField(tokenizer, message) + + def _MergeField(self, tokenizer, message): + """Merges a single protocol message field into a message. + + Args: + tokenizer: A tokenizer to parse the field name and values. + message: A protocol message to record the data. + + Raises: + ParseError: In case of text parsing problems. + """ + message_descriptor = message.DESCRIPTOR + if (hasattr(message_descriptor, 'syntax') and + message_descriptor.syntax == 'proto3'): + # Proto3 doesn't represent presence so we can't test if multiple + # scalars have occurred. We have to allow them. + self._allow_multiple_scalars = True + if tokenizer.TryConsume('['): + name = [tokenizer.ConsumeIdentifier()] + while tokenizer.TryConsume('.'): + name.append(tokenizer.ConsumeIdentifier()) + name = '.'.join(name) + + if not message_descriptor.is_extendable: + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" does not have extensions.' % + message_descriptor.full_name) + # pylint: disable=protected-access + field = message.Extensions._FindExtensionByName(name) + # pylint: enable=protected-access + if not field: + if self.allow_unknown_extension: + field = None + else: + raise tokenizer.ParseErrorPreviousToken( + 'Extension "%s" not registered.' % name) + elif message_descriptor != field.containing_type: raise tokenizer.ParseErrorPreviousToken( - 'Field "%s" is specified along with field "%s", another member of ' - 'oneof "%s" for message type "%s".' % ( - field.name, which_oneof, field.containing_oneof.name, - message_descriptor.full_name)) + 'Extension "%s" does not extend message type "%s".' % ( + name, message_descriptor.full_name)) + + tokenizer.Consume(']') - if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - tokenizer.TryConsume(':') - merger = _MergeMessageField else: - tokenizer.Consume(':') - merger = _MergeScalarField - - if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED - and tokenizer.TryConsume('[')): - # Short repeated format, e.g. "foo: [1, 2, 3]" - while True: - merger(tokenizer, message, field, allow_multiple_scalars, - allow_unknown_extension) - if tokenizer.TryConsume(']'): break - tokenizer.Consume(',') + name = tokenizer.ConsumeIdentifier() + if self.allow_field_number and name.isdigit(): + number = ParseInteger(name, True, True) + field = message_descriptor.fields_by_number.get(number, None) + if not field and message_descriptor.is_extendable: + field = message.Extensions._FindExtensionByNumber(number) + else: + field = message_descriptor.fields_by_name.get(name, None) + + # Group names are expected to be capitalized as they appear in the + # .proto file, which actually matches their type names, not their field + # names. + if not field: + field = message_descriptor.fields_by_name.get(name.lower(), None) + if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP: + field = None + + if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and + field.message_type.name != name): + field = None + + if not field: + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" has no field named "%s".' % ( + message_descriptor.full_name, name)) + + if field: + if not self._allow_multiple_scalars and field.containing_oneof: + # Check if there's a different field set in this oneof. + # Note that we ignore the case if the same field was set before, and we + # apply _allow_multiple_scalars to non-scalar fields as well. + which_oneof = message.WhichOneof(field.containing_oneof.name) + if which_oneof is not None and which_oneof != field.name: + raise tokenizer.ParseErrorPreviousToken( + 'Field "%s" is specified along with field "%s", another member ' + 'of oneof "%s" for message type "%s".' % ( + field.name, which_oneof, field.containing_oneof.name, + message_descriptor.full_name)) + + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + tokenizer.TryConsume(':') + merger = self._MergeMessageField + else: + tokenizer.Consume(':') + merger = self._MergeScalarField + + if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED + and tokenizer.TryConsume('[')): + # Short repeated format, e.g. "foo: [1, 2, 3]" + while True: + merger(tokenizer, message, field) + if tokenizer.TryConsume(']'): break + tokenizer.Consume(',') + + else: + merger(tokenizer, message, field) + + else: # Proto field is unknown. + assert self.allow_unknown_extension + _SkipFieldContents(tokenizer) + + # For historical reasons, fields may optionally be separated by commas or + # semicolons. + if not tokenizer.TryConsume(','): + tokenizer.TryConsume(';') + + def _MergeMessageField(self, tokenizer, message, field): + """Merges a single scalar field into a message. + + Args: + tokenizer: A tokenizer to parse the field value. + message: The message of which field is a member. + field: The descriptor of the field to be merged. + + Raises: + ParseError: In case of text parsing problems. + """ + is_map_entry = _IsMapEntry(field) + if tokenizer.TryConsume('<'): + end_token = '>' else: - merger(tokenizer, message, field, allow_multiple_scalars, - allow_unknown_extension) + tokenizer.Consume('{') + end_token = '}' + + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_extension: + sub_message = message.Extensions[field].add() + elif is_map_entry: + # pylint: disable=protected-access + sub_message = field.message_type._concrete_class() + else: + sub_message = getattr(message, field.name).add() + else: + if field.is_extension: + sub_message = message.Extensions[field] + else: + sub_message = getattr(message, field.name) + sub_message.SetInParent() + + while not tokenizer.TryConsume(end_token): + if tokenizer.AtEnd(): + raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token,)) + self._MergeField(tokenizer, sub_message) + + if is_map_entry: + value_cpptype = field.message_type.fields_by_name['value'].cpp_type + if value_cpptype == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + value = getattr(message, field.name)[sub_message.key] + value.MergeFrom(sub_message.value) + else: + getattr(message, field.name)[sub_message.key] = sub_message.value - else: # Proto field is unknown. - assert allow_unknown_extension - _SkipFieldContents(tokenizer) + def _MergeScalarField(self, tokenizer, message, field): + """Merges a single scalar field into a message. - # For historical reasons, fields may optionally be separated by commas or - # semicolons. - if not tokenizer.TryConsume(','): - tokenizer.TryConsume(';') + Args: + tokenizer: A tokenizer to parse the field value. + message: A protocol message to record the data. + field: The descriptor of the field to be merged. + + Raises: + ParseError: In case of text parsing problems. + RuntimeError: On runtime errors. + """ + _ = self.allow_unknown_extension + value = None + + if field.type in (descriptor.FieldDescriptor.TYPE_INT32, + descriptor.FieldDescriptor.TYPE_SINT32, + descriptor.FieldDescriptor.TYPE_SFIXED32): + value = tokenizer.ConsumeInt32() + elif field.type in (descriptor.FieldDescriptor.TYPE_INT64, + descriptor.FieldDescriptor.TYPE_SINT64, + descriptor.FieldDescriptor.TYPE_SFIXED64): + value = tokenizer.ConsumeInt64() + elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32, + descriptor.FieldDescriptor.TYPE_FIXED32): + value = tokenizer.ConsumeUint32() + elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64, + descriptor.FieldDescriptor.TYPE_FIXED64): + value = tokenizer.ConsumeUint64() + elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT, + descriptor.FieldDescriptor.TYPE_DOUBLE): + value = tokenizer.ConsumeFloat() + elif field.type == descriptor.FieldDescriptor.TYPE_BOOL: + value = tokenizer.ConsumeBool() + elif field.type == descriptor.FieldDescriptor.TYPE_STRING: + value = tokenizer.ConsumeString() + elif field.type == descriptor.FieldDescriptor.TYPE_BYTES: + value = tokenizer.ConsumeByteString() + elif field.type == descriptor.FieldDescriptor.TYPE_ENUM: + value = tokenizer.ConsumeEnum(field) + else: + raise RuntimeError('Unknown field type %d' % field.type) + + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_extension: + message.Extensions[field].append(value) + else: + getattr(message, field.name).append(value) + else: + if field.is_extension: + if not self._allow_multiple_scalars and message.HasExtension(field): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" extensions.' % + (message.DESCRIPTOR.full_name, field.full_name)) + else: + message.Extensions[field] = value + else: + if not self._allow_multiple_scalars and message.HasField(field.name): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" fields.' % + (message.DESCRIPTOR.full_name, field.name)) + else: + setattr(message, field.name, value) def _SkipFieldContents(tokenizer): @@ -555,10 +745,10 @@ def _SkipFieldValue(tokenizer): Raises: ParseError: In case an invalid field value is found. """ - # String tokens can come in multiple adjacent string literals. + # String/bytes tokens can come in multiple adjacent string literals. # If we can consume one, consume as many as we can. - if tokenizer.TryConsumeString(): - while tokenizer.TryConsumeString(): + if tokenizer.TryConsumeByteString(): + while tokenizer.TryConsumeByteString(): pass return @@ -569,132 +759,6 @@ def _SkipFieldValue(tokenizer): raise ParseError('Invalid field value: ' + tokenizer.token) -def _MergeMessageField(tokenizer, message, field, allow_multiple_scalars, - allow_unknown_extension): - """Merges a single scalar field into a message. - - Args: - tokenizer: A tokenizer to parse the field value. - message: The message of which field is a member. - field: The descriptor of the field to be merged. - allow_multiple_scalars: Determines if repeated values for a non-repeated - field are permitted, e.g., the string "foo: 1 foo: 2" for a - required/optional field named "foo". - allow_unknown_extension: if True, skip over missing extensions and keep - parsing. - - Raises: - ParseError: In case of text parsing problems. - """ - is_map_entry = _IsMapEntry(field) - - if tokenizer.TryConsume('<'): - end_token = '>' - else: - tokenizer.Consume('{') - end_token = '}' - - if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - if field.is_extension: - sub_message = message.Extensions[field].add() - elif is_map_entry: - # pylint: disable=protected-access - sub_message = field.message_type._concrete_class() - else: - sub_message = getattr(message, field.name).add() - else: - if field.is_extension: - sub_message = message.Extensions[field] - else: - sub_message = getattr(message, field.name) - sub_message.SetInParent() - - while not tokenizer.TryConsume(end_token): - if tokenizer.AtEnd(): - raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token,)) - _MergeField(tokenizer, sub_message, allow_multiple_scalars, - allow_unknown_extension) - - if is_map_entry: - value_cpptype = field.message_type.fields_by_name['value'].cpp_type - if value_cpptype == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - value = getattr(message, field.name)[sub_message.key] - value.MergeFrom(sub_message.value) - else: - getattr(message, field.name)[sub_message.key] = sub_message.value - - -def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars, - allow_unknown_extension): - """Merges a single scalar field into a message. - - Args: - tokenizer: A tokenizer to parse the field value. - message: A protocol message to record the data. - field: The descriptor of the field to be merged. - allow_multiple_scalars: Determines if repeated values for a non-repeated - field are permitted, e.g., the string "foo: 1 foo: 2" for a - required/optional field named "foo". - allow_unknown_extension: Unused, just here for consistency with - _MergeMessageField. - - Raises: - ParseError: In case of text parsing problems. - RuntimeError: On runtime errors. - """ - _ = allow_unknown_extension - value = None - - if field.type in (descriptor.FieldDescriptor.TYPE_INT32, - descriptor.FieldDescriptor.TYPE_SINT32, - descriptor.FieldDescriptor.TYPE_SFIXED32): - value = tokenizer.ConsumeInt32() - elif field.type in (descriptor.FieldDescriptor.TYPE_INT64, - descriptor.FieldDescriptor.TYPE_SINT64, - descriptor.FieldDescriptor.TYPE_SFIXED64): - value = tokenizer.ConsumeInt64() - elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32, - descriptor.FieldDescriptor.TYPE_FIXED32): - value = tokenizer.ConsumeUint32() - elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64, - descriptor.FieldDescriptor.TYPE_FIXED64): - value = tokenizer.ConsumeUint64() - elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT, - descriptor.FieldDescriptor.TYPE_DOUBLE): - value = tokenizer.ConsumeFloat() - elif field.type == descriptor.FieldDescriptor.TYPE_BOOL: - value = tokenizer.ConsumeBool() - elif field.type == descriptor.FieldDescriptor.TYPE_STRING: - value = tokenizer.ConsumeString() - elif field.type == descriptor.FieldDescriptor.TYPE_BYTES: - value = tokenizer.ConsumeByteString() - elif field.type == descriptor.FieldDescriptor.TYPE_ENUM: - value = tokenizer.ConsumeEnum(field) - else: - raise RuntimeError('Unknown field type %d' % field.type) - - if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - if field.is_extension: - message.Extensions[field].append(value) - else: - getattr(message, field.name).append(value) - else: - if field.is_extension: - if not allow_multiple_scalars and message.HasExtension(field): - raise tokenizer.ParseErrorPreviousToken( - 'Message type "%s" should not have multiple "%s" extensions.' % - (message.DESCRIPTOR.full_name, field.full_name)) - else: - message.Extensions[field] = value - else: - if not allow_multiple_scalars and message.HasField(field.name): - raise tokenizer.ParseErrorPreviousToken( - 'Message type "%s" should not have multiple "%s" fields.' % - (message.DESCRIPTOR.full_name, field.name)) - else: - setattr(message, field.name, value) - - class _Tokenizer(object): """Protocol buffer text representation tokenizer. @@ -925,9 +989,9 @@ class _Tokenizer(object): self.NextToken() return result - def TryConsumeString(self): + def TryConsumeByteString(self): try: - self.ConsumeString() + self.ConsumeByteString() return True except ParseError: return False -- cgit v1.2.3