diff options
Diffstat (limited to 'python/google/protobuf/internal/well_known_types.py')
-rw-r--r-- | python/google/protobuf/internal/well_known_types.py | 159 |
1 files changed, 139 insertions, 20 deletions
diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py index d35fcc5f..37a65cfa 100644 --- a/python/google/protobuf/internal/well_known_types.py +++ b/python/google/protobuf/internal/well_known_types.py @@ -40,6 +40,7 @@ This files defines well known classes which need extra maintenance including: __author__ = 'jieluo@google.com (Jie Luo)' +import collections from datetime import datetime from datetime import timedelta import six @@ -53,6 +54,7 @@ _NANOS_PER_MICROSECOND = 1000 _MILLIS_PER_SECOND = 1000 _MICROS_PER_SECOND = 1000000 _SECONDS_PER_DAY = 24 * 3600 +_DURATION_SECONDS_MAX = 315576000000 class Error(Exception): @@ -66,13 +68,14 @@ class ParseError(Error): class Any(object): """Class for Any Message type.""" - def Pack(self, msg, type_url_prefix='type.googleapis.com/'): + def Pack(self, msg, type_url_prefix='type.googleapis.com/', + deterministic=None): """Packs the specified message into current Any message.""" if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/': self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name) else: self.type_url = '%s%s' % (type_url_prefix, msg.DESCRIPTOR.full_name) - self.value = msg.SerializeToString() + self.value = msg.SerializeToString(deterministic=deterministic) def Unpack(self, msg): """Unpacks the current Any message into specified message.""" @@ -82,10 +85,14 @@ class Any(object): msg.ParseFromString(self.value) return True + def TypeName(self): + """Returns the protobuf type name of the inner message.""" + # Only last part is to be used: b/25630112 + return self.type_url.split('/')[-1] + def Is(self, descriptor): """Checks if this Any represents the given protobuf type.""" - # Only last part is to be used: b/25630112 - return self.type_url.split('/')[-1] == descriptor.full_name + return self.TypeName() == descriptor.full_name class Timestamp(object): @@ -243,6 +250,7 @@ class Duration(object): represent the exact Duration value. For example: "1s", "1.010s", "1.000000100s", "-3.100s" """ + _CheckDurationValid(self.seconds, self.nanos) if self.seconds < 0 or self.nanos < 0: result = '-' seconds = - self.seconds + int((0 - self.nanos) // 1e9) @@ -282,14 +290,17 @@ class Duration(object): try: pos = value.find('.') if pos == -1: - self.seconds = int(value[:-1]) - self.nanos = 0 + seconds = int(value[:-1]) + nanos = 0 else: - self.seconds = int(value[:pos]) + seconds = int(value[:pos]) if value[0] == '-': - self.nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9)) + nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9)) else: - self.nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9)) + nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9)) + _CheckDurationValid(seconds, nanos) + self.seconds = seconds + self.nanos = nanos except ValueError: raise ParseError( 'Couldn\'t parse duration: {0}.'.format(value)) @@ -341,12 +352,12 @@ class Duration(object): self.nanos, _NANOS_PER_MICROSECOND)) def FromTimedelta(self, td): - """Convertd timedelta to Duration.""" + """Converts timedelta to Duration.""" self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY, td.microseconds * _NANOS_PER_MICROSECOND) def _NormalizeDuration(self, seconds, nanos): - """Set Duration by seconds and nonas.""" + """Set Duration by seconds and nanos.""" # Force nanos to be negative if the duration is negative. if seconds < 0 and nanos > 0: seconds += 1 @@ -355,6 +366,20 @@ class Duration(object): self.nanos = nanos +def _CheckDurationValid(seconds, nanos): + if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX: + raise Error( + 'Duration is not valid: Seconds {0} must be in range ' + '[-315576000000, 315576000000].'.format(seconds)) + if nanos <= -_NANOS_PER_SECOND or nanos >= _NANOS_PER_SECOND: + raise Error( + 'Duration is not valid: Nanos {0} must be in range ' + '[-999999999, 999999999].'.format(nanos)) + if (nanos < 0 and seconds > 0) or (nanos > 0 and seconds < 0): + raise Error( + 'Duration is not valid: Sign mismatch.') + + def _RoundTowardZero(value, divider): """Truncates the remainder part after division.""" # For some languanges, the sign of the remainder is implementation @@ -375,13 +400,16 @@ class FieldMask(object): def ToJsonString(self): """Converts FieldMask to string according to proto3 JSON spec.""" - return ','.join(self.paths) + camelcase_paths = [] + for path in self.paths: + camelcase_paths.append(_SnakeCaseToCamelCase(path)) + return ','.join(camelcase_paths) def FromJsonString(self, value): """Converts string to FieldMask according to proto3 JSON spec.""" self.Clear() for path in value.split(','): - self.paths.append(path) + self.paths.append(_CamelCaseToSnakeCase(path)) def IsValidForDescriptor(self, message_descriptor): """Checks whether the FieldMask is valid for Message Descriptor.""" @@ -450,7 +478,7 @@ def _IsValidPath(message_descriptor, path): parts = path.split('.') last = parts.pop() for name in parts: - field = message_descriptor.fields_by_name[name] + field = message_descriptor.fields_by_name.get(name) if (field is None or field.label == FieldDescriptor.LABEL_REPEATED or field.type != FieldDescriptor.TYPE_MESSAGE): @@ -468,6 +496,48 @@ def _CheckFieldMaskMessage(message): message_descriptor.full_name)) +def _SnakeCaseToCamelCase(path_name): + """Converts a path name from snake_case to camelCase.""" + result = [] + after_underscore = False + for c in path_name: + if c.isupper(): + raise Error('Fail to print FieldMask to Json string: Path name ' + '{0} must not contain uppercase letters.'.format(path_name)) + if after_underscore: + if c.islower(): + result.append(c.upper()) + after_underscore = False + else: + raise Error('Fail to print FieldMask to Json string: The ' + 'character after a "_" must be a lowercase letter ' + 'in path name {0}.'.format(path_name)) + elif c == '_': + after_underscore = True + else: + result += c + + if after_underscore: + raise Error('Fail to print FieldMask to Json string: Trailing "_" ' + 'in path name {0}.'.format(path_name)) + return ''.join(result) + + +def _CamelCaseToSnakeCase(path_name): + """Converts a field name from camelCase to snake_case.""" + result = [] + for c in path_name: + if c == '_': + raise ParseError('Fail to parse FieldMask: Path name ' + '{0} must not contain "_"s.'.format(path_name)) + if c.isupper(): + result += '_' + result += c.lower() + else: + result += c + return ''.join(result) + + class _FieldMaskTree(object): """Represents a FieldMask in a tree structure. @@ -582,9 +652,10 @@ def _MergeMessage( raise ValueError('Error: Field {0} in message {1} is not a singular ' 'message field and cannot have sub-fields.'.format( name, source_descriptor.full_name)) - _MergeMessage( - child, getattr(source, name), getattr(destination, name), - replace_message, replace_repeated) + if source.HasField(name): + _MergeMessage( + child, getattr(source, name), getattr(destination, name), + replace_message, replace_repeated) continue if field.label == FieldDescriptor.LABEL_REPEATED: if replace_repeated: @@ -633,6 +704,12 @@ def _SetStructValue(struct_value, value): struct_value.string_value = value elif isinstance(value, _INT_OR_FLOAT): struct_value.number_value = value + elif isinstance(value, dict): + struct_value.struct_value.Clear() + struct_value.struct_value.update(value) + elif isinstance(value, list): + struct_value.list_value.Clear() + struct_value.list_value.extend(value) else: raise ValueError('Unexpected type') @@ -663,18 +740,49 @@ class Struct(object): def __getitem__(self, key): return _GetStructValue(self.fields[key]) + def __contains__(self, item): + return item in self.fields + def __setitem__(self, key, value): _SetStructValue(self.fields[key], value) + def __delitem__(self, key): + del self.fields[key] + + def __len__(self): + return len(self.fields) + + def __iter__(self): + return iter(self.fields) + + def keys(self): # pylint: disable=invalid-name + return self.fields.keys() + + def values(self): # pylint: disable=invalid-name + return [self[key] for key in self] + + def items(self): # pylint: disable=invalid-name + return [(key, self[key]) for key in self] + def get_or_create_list(self, key): """Returns a list for this key, creating if it didn't exist already.""" + if not self.fields[key].HasField('list_value'): + # Clear will mark list_value modified which will indeed create a list. + self.fields[key].list_value.Clear() return self.fields[key].list_value def get_or_create_struct(self, key): """Returns a struct for this key, creating if it didn't exist already.""" + if not self.fields[key].HasField('struct_value'): + # Clear will mark struct_value modified which will indeed create a struct. + self.fields[key].struct_value.Clear() return self.fields[key].struct_value - # TODO(haberman): allow constructing/merging from dict. + def update(self, dictionary): # pylint: disable=invalid-name + for key, value in dictionary.items(): + _SetStructValue(self.fields[key], value) + +collections.MutableMapping.register(Struct) class ListValue(object): @@ -697,17 +805,28 @@ class ListValue(object): def __setitem__(self, index, value): _SetStructValue(self.values.__getitem__(index), value) + def __delitem__(self, key): + del self.values[key] + def items(self): for i in range(len(self)): yield self[i] def add_struct(self): """Appends and returns a struct value as the next value in the list.""" - return self.values.add().struct_value + struct_value = self.values.add().struct_value + # Clear will mark struct_value modified which will indeed create a struct. + struct_value.Clear() + return struct_value def add_list(self): """Appends and returns a list value as the next value in the list.""" - return self.values.add().list_value + list_value = self.values.add().list_value + # Clear will mark list_value modified which will indeed create a list. + list_value.Clear() + return list_value + +collections.MutableSequence.register(ListValue) WKTBASES = { |