diff options
Diffstat (limited to 'python/google/protobuf/text_format.py')
-rwxr-xr-x | python/google/protobuf/text_format.py | 137 |
1 files changed, 90 insertions, 47 deletions
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index 8d256076..a6f41ca8 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -384,7 +384,7 @@ def _MergeField(tokenizer, field are permitted, e.g., the string "foo: 1 foo: 2" for a required/optional field named "foo". allow_unknown_extension: if True, skip over missing extensions and keep - parsing + parsing. Raises: ParseError: In case of text parsing problems. @@ -442,55 +442,39 @@ def _MergeField(tokenizer, 'Message type "%s" has no field named "%s".' % ( message_descriptor.full_name, name)) - if field and field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - is_map_entry = _IsMapEntry(field) - tokenizer.TryConsume(':') - - if tokenizer.TryConsume('<'): - end_token = '>' - else: - tokenizer.Consume('{') - end_token = '}' - - if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - if field.is_extension: - sub_message = message.Extensions[field].add() - elif is_map_entry: - sub_message = field.message_type._concrete_class() - else: - sub_message = getattr(message, field.name).add() + if field: + if not allow_multiple_scalars and field.containing_oneof: + # Check if there's a different field set in this oneof. + # Note that we ignore the case if the same field was set before, and we + # apply allow_multiple_scalars to non-scalar fields as well. + which_oneof = message.WhichOneof(field.containing_oneof.name) + if which_oneof is not None and which_oneof != field.name: + raise tokenizer.ParseErrorPreviousToken( + 'Field "%s" is specified along with field "%s", another member of ' + 'oneof "%s" for message type "%s".' % ( + field.name, which_oneof, field.containing_oneof.name, + message_descriptor.full_name)) + + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + tokenizer.TryConsume(':') + merger = _MergeMessageField else: - if field.is_extension: - sub_message = message.Extensions[field] - else: - sub_message = getattr(message, field.name) - sub_message.SetInParent() - - while not tokenizer.TryConsume(end_token): - if tokenizer.AtEnd(): - raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token)) - _MergeField(tokenizer, sub_message, allow_multiple_scalars, - allow_unknown_extension) - - if is_map_entry: - value_cpptype = field.message_type.fields_by_name['value'].cpp_type - if value_cpptype == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - value = getattr(message, field.name)[sub_message.key] - value.MergeFrom(sub_message.value) - else: - getattr(message, field.name)[sub_message.key] = sub_message.value - elif field: - tokenizer.Consume(':') - if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED and - tokenizer.TryConsume('[')): + tokenizer.Consume(':') + merger = _MergeScalarField + + if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED + and tokenizer.TryConsume('[')): # Short repeated format, e.g. "foo: [1, 2, 3]" while True: - _MergeScalarField(tokenizer, message, field, allow_multiple_scalars) - if tokenizer.TryConsume(']'): - break + merger(tokenizer, message, field, allow_multiple_scalars, + allow_unknown_extension) + if tokenizer.TryConsume(']'): break tokenizer.Consume(',') + else: - _MergeScalarField(tokenizer, message, field, allow_multiple_scalars) + merger(tokenizer, message, field, allow_multiple_scalars, + allow_unknown_extension) + else: # Proto field is unknown. assert allow_unknown_extension _SkipFieldContents(tokenizer) @@ -585,8 +569,64 @@ def _SkipFieldValue(tokenizer): raise ParseError('Invalid field value: ' + tokenizer.token) -def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars): - """Merges a single protocol message scalar field into a message. +def _MergeMessageField(tokenizer, message, field, allow_multiple_scalars, + allow_unknown_extension): + """Merges a single scalar field into a message. + + Args: + tokenizer: A tokenizer to parse the field value. + message: The message of which field is a member. + field: The descriptor of the field to be merged. + allow_multiple_scalars: Determines if repeated values for a non-repeated + field are permitted, e.g., the string "foo: 1 foo: 2" for a + required/optional field named "foo". + allow_unknown_extension: if True, skip over missing extensions and keep + parsing. + + Raises: + ParseError: In case of text parsing problems. + """ + is_map_entry = _IsMapEntry(field) + + if tokenizer.TryConsume('<'): + end_token = '>' + else: + tokenizer.Consume('{') + end_token = '}' + + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_extension: + sub_message = message.Extensions[field].add() + elif is_map_entry: + # pylint: disable=protected-access + sub_message = field.message_type._concrete_class() + else: + sub_message = getattr(message, field.name).add() + else: + if field.is_extension: + sub_message = message.Extensions[field] + else: + sub_message = getattr(message, field.name) + sub_message.SetInParent() + + while not tokenizer.TryConsume(end_token): + if tokenizer.AtEnd(): + raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token,)) + _MergeField(tokenizer, sub_message, allow_multiple_scalars, + allow_unknown_extension) + + if is_map_entry: + value_cpptype = field.message_type.fields_by_name['value'].cpp_type + if value_cpptype == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + value = getattr(message, field.name)[sub_message.key] + value.MergeFrom(sub_message.value) + else: + getattr(message, field.name)[sub_message.key] = sub_message.value + + +def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars, + allow_unknown_extension): + """Merges a single scalar field into a message. Args: tokenizer: A tokenizer to parse the field value. @@ -595,11 +635,14 @@ def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars): allow_multiple_scalars: Determines if repeated values for a non-repeated field are permitted, e.g., the string "foo: 1 foo: 2" for a required/optional field named "foo". + allow_unknown_extension: Unused, just here for consistency with + _MergeMessageField. Raises: ParseError: In case of text parsing problems. RuntimeError: On runtime errors. """ + _ = allow_unknown_extension value = None if field.type in (descriptor.FieldDescriptor.TYPE_INT32, |