aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/internal/well_known_types.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/internal/well_known_types.py')
-rw-r--r--python/google/protobuf/internal/well_known_types.py159
1 files changed, 139 insertions, 20 deletions
diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py
index d35fcc5f..37a65cfa 100644
--- a/python/google/protobuf/internal/well_known_types.py
+++ b/python/google/protobuf/internal/well_known_types.py
@@ -40,6 +40,7 @@ This files defines well known classes which need extra maintenance including:
__author__ = 'jieluo@google.com (Jie Luo)'
+import collections
from datetime import datetime
from datetime import timedelta
import six
@@ -53,6 +54,7 @@ _NANOS_PER_MICROSECOND = 1000
_MILLIS_PER_SECOND = 1000
_MICROS_PER_SECOND = 1000000
_SECONDS_PER_DAY = 24 * 3600
+_DURATION_SECONDS_MAX = 315576000000
class Error(Exception):
@@ -66,13 +68,14 @@ class ParseError(Error):
class Any(object):
"""Class for Any Message type."""
- def Pack(self, msg, type_url_prefix='type.googleapis.com/'):
+ def Pack(self, msg, type_url_prefix='type.googleapis.com/',
+ deterministic=None):
"""Packs the specified message into current Any message."""
if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/':
self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
else:
self.type_url = '%s%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
- self.value = msg.SerializeToString()
+ self.value = msg.SerializeToString(deterministic=deterministic)
def Unpack(self, msg):
"""Unpacks the current Any message into specified message."""
@@ -82,10 +85,14 @@ class Any(object):
msg.ParseFromString(self.value)
return True
+ def TypeName(self):
+ """Returns the protobuf type name of the inner message."""
+ # Only last part is to be used: b/25630112
+ return self.type_url.split('/')[-1]
+
def Is(self, descriptor):
"""Checks if this Any represents the given protobuf type."""
- # Only last part is to be used: b/25630112
- return self.type_url.split('/')[-1] == descriptor.full_name
+ return self.TypeName() == descriptor.full_name
class Timestamp(object):
@@ -243,6 +250,7 @@ class Duration(object):
represent the exact Duration value. For example: "1s", "1.010s",
"1.000000100s", "-3.100s"
"""
+ _CheckDurationValid(self.seconds, self.nanos)
if self.seconds < 0 or self.nanos < 0:
result = '-'
seconds = - self.seconds + int((0 - self.nanos) // 1e9)
@@ -282,14 +290,17 @@ class Duration(object):
try:
pos = value.find('.')
if pos == -1:
- self.seconds = int(value[:-1])
- self.nanos = 0
+ seconds = int(value[:-1])
+ nanos = 0
else:
- self.seconds = int(value[:pos])
+ seconds = int(value[:pos])
if value[0] == '-':
- self.nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
+ nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
else:
- self.nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
+ nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
+ _CheckDurationValid(seconds, nanos)
+ self.seconds = seconds
+ self.nanos = nanos
except ValueError:
raise ParseError(
'Couldn\'t parse duration: {0}.'.format(value))
@@ -341,12 +352,12 @@ class Duration(object):
self.nanos, _NANOS_PER_MICROSECOND))
def FromTimedelta(self, td):
- """Convertd timedelta to Duration."""
+ """Converts timedelta to Duration."""
self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
td.microseconds * _NANOS_PER_MICROSECOND)
def _NormalizeDuration(self, seconds, nanos):
- """Set Duration by seconds and nonas."""
+ """Set Duration by seconds and nanos."""
# Force nanos to be negative if the duration is negative.
if seconds < 0 and nanos > 0:
seconds += 1
@@ -355,6 +366,20 @@ class Duration(object):
self.nanos = nanos
+def _CheckDurationValid(seconds, nanos):
+ if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX:
+ raise Error(
+ 'Duration is not valid: Seconds {0} must be in range '
+ '[-315576000000, 315576000000].'.format(seconds))
+ if nanos <= -_NANOS_PER_SECOND or nanos >= _NANOS_PER_SECOND:
+ raise Error(
+ 'Duration is not valid: Nanos {0} must be in range '
+ '[-999999999, 999999999].'.format(nanos))
+ if (nanos < 0 and seconds > 0) or (nanos > 0 and seconds < 0):
+ raise Error(
+ 'Duration is not valid: Sign mismatch.')
+
+
def _RoundTowardZero(value, divider):
"""Truncates the remainder part after division."""
# For some languanges, the sign of the remainder is implementation
@@ -375,13 +400,16 @@ class FieldMask(object):
def ToJsonString(self):
"""Converts FieldMask to string according to proto3 JSON spec."""
- return ','.join(self.paths)
+ camelcase_paths = []
+ for path in self.paths:
+ camelcase_paths.append(_SnakeCaseToCamelCase(path))
+ return ','.join(camelcase_paths)
def FromJsonString(self, value):
"""Converts string to FieldMask according to proto3 JSON spec."""
self.Clear()
for path in value.split(','):
- self.paths.append(path)
+ self.paths.append(_CamelCaseToSnakeCase(path))
def IsValidForDescriptor(self, message_descriptor):
"""Checks whether the FieldMask is valid for Message Descriptor."""
@@ -450,7 +478,7 @@ def _IsValidPath(message_descriptor, path):
parts = path.split('.')
last = parts.pop()
for name in parts:
- field = message_descriptor.fields_by_name[name]
+ field = message_descriptor.fields_by_name.get(name)
if (field is None or
field.label == FieldDescriptor.LABEL_REPEATED or
field.type != FieldDescriptor.TYPE_MESSAGE):
@@ -468,6 +496,48 @@ def _CheckFieldMaskMessage(message):
message_descriptor.full_name))
+def _SnakeCaseToCamelCase(path_name):
+ """Converts a path name from snake_case to camelCase."""
+ result = []
+ after_underscore = False
+ for c in path_name:
+ if c.isupper():
+ raise Error('Fail to print FieldMask to Json string: Path name '
+ '{0} must not contain uppercase letters.'.format(path_name))
+ if after_underscore:
+ if c.islower():
+ result.append(c.upper())
+ after_underscore = False
+ else:
+ raise Error('Fail to print FieldMask to Json string: The '
+ 'character after a "_" must be a lowercase letter '
+ 'in path name {0}.'.format(path_name))
+ elif c == '_':
+ after_underscore = True
+ else:
+ result += c
+
+ if after_underscore:
+ raise Error('Fail to print FieldMask to Json string: Trailing "_" '
+ 'in path name {0}.'.format(path_name))
+ return ''.join(result)
+
+
+def _CamelCaseToSnakeCase(path_name):
+ """Converts a field name from camelCase to snake_case."""
+ result = []
+ for c in path_name:
+ if c == '_':
+ raise ParseError('Fail to parse FieldMask: Path name '
+ '{0} must not contain "_"s.'.format(path_name))
+ if c.isupper():
+ result += '_'
+ result += c.lower()
+ else:
+ result += c
+ return ''.join(result)
+
+
class _FieldMaskTree(object):
"""Represents a FieldMask in a tree structure.
@@ -582,9 +652,10 @@ def _MergeMessage(
raise ValueError('Error: Field {0} in message {1} is not a singular '
'message field and cannot have sub-fields.'.format(
name, source_descriptor.full_name))
- _MergeMessage(
- child, getattr(source, name), getattr(destination, name),
- replace_message, replace_repeated)
+ if source.HasField(name):
+ _MergeMessage(
+ child, getattr(source, name), getattr(destination, name),
+ replace_message, replace_repeated)
continue
if field.label == FieldDescriptor.LABEL_REPEATED:
if replace_repeated:
@@ -633,6 +704,12 @@ def _SetStructValue(struct_value, value):
struct_value.string_value = value
elif isinstance(value, _INT_OR_FLOAT):
struct_value.number_value = value
+ elif isinstance(value, dict):
+ struct_value.struct_value.Clear()
+ struct_value.struct_value.update(value)
+ elif isinstance(value, list):
+ struct_value.list_value.Clear()
+ struct_value.list_value.extend(value)
else:
raise ValueError('Unexpected type')
@@ -663,18 +740,49 @@ class Struct(object):
def __getitem__(self, key):
return _GetStructValue(self.fields[key])
+ def __contains__(self, item):
+ return item in self.fields
+
def __setitem__(self, key, value):
_SetStructValue(self.fields[key], value)
+ def __delitem__(self, key):
+ del self.fields[key]
+
+ def __len__(self):
+ return len(self.fields)
+
+ def __iter__(self):
+ return iter(self.fields)
+
+ def keys(self): # pylint: disable=invalid-name
+ return self.fields.keys()
+
+ def values(self): # pylint: disable=invalid-name
+ return [self[key] for key in self]
+
+ def items(self): # pylint: disable=invalid-name
+ return [(key, self[key]) for key in self]
+
def get_or_create_list(self, key):
"""Returns a list for this key, creating if it didn't exist already."""
+ if not self.fields[key].HasField('list_value'):
+ # Clear will mark list_value modified which will indeed create a list.
+ self.fields[key].list_value.Clear()
return self.fields[key].list_value
def get_or_create_struct(self, key):
"""Returns a struct for this key, creating if it didn't exist already."""
+ if not self.fields[key].HasField('struct_value'):
+ # Clear will mark struct_value modified which will indeed create a struct.
+ self.fields[key].struct_value.Clear()
return self.fields[key].struct_value
- # TODO(haberman): allow constructing/merging from dict.
+ def update(self, dictionary): # pylint: disable=invalid-name
+ for key, value in dictionary.items():
+ _SetStructValue(self.fields[key], value)
+
+collections.MutableMapping.register(Struct)
class ListValue(object):
@@ -697,17 +805,28 @@ class ListValue(object):
def __setitem__(self, index, value):
_SetStructValue(self.values.__getitem__(index), value)
+ def __delitem__(self, key):
+ del self.values[key]
+
def items(self):
for i in range(len(self)):
yield self[i]
def add_struct(self):
"""Appends and returns a struct value as the next value in the list."""
- return self.values.add().struct_value
+ struct_value = self.values.add().struct_value
+ # Clear will mark struct_value modified which will indeed create a struct.
+ struct_value.Clear()
+ return struct_value
def add_list(self):
"""Appends and returns a list value as the next value in the list."""
- return self.values.add().list_value
+ list_value = self.values.add().list_value
+ # Clear will mark list_value modified which will indeed create a list.
+ list_value.Clear()
+ return list_value
+
+collections.MutableSequence.register(ListValue)
WKTBASES = {