From 5477f8cdbab0103beba17fc90ae8730835ea427f Mon Sep 17 00:00:00 2001 From: CH Albach Date: Fri, 29 Jan 2016 18:10:50 -0800 Subject: Manually down-integrate python JSON struct support from internal code base. --- .gitignore | 1 + .../google/protobuf/internal/json_format_test.py | 244 ++++++++++++++++++++- .../google/protobuf/internal/well_known_types.py | 102 ++++++++- .../protobuf/internal/well_known_types_test.py | 123 +++++++++++ python/google/protobuf/json_format.py | 214 ++++++++++++++++-- 5 files changed, 659 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index 19e52129..a2d6ca9c 100644 --- a/.gitignore +++ b/.gitignore @@ -58,6 +58,7 @@ python/.eggs/ python/.tox python/build/ python/google/protobuf/compiler/ +python/google/protobuf/util/ src/protoc src/unittest_proto_middleman diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py index be3ad11a..49e96a46 100644 --- a/python/google/protobuf/internal/json_format_test.py +++ b/python/google/protobuf/internal/json_format_test.py @@ -42,6 +42,12 @@ try: import unittest2 as unittest except ImportError: import unittest +from google.protobuf import any_pb2 +from google.protobuf import duration_pb2 +from google.protobuf import field_mask_pb2 +from google.protobuf import struct_pb2 +from google.protobuf import timestamp_pb2 +from google.protobuf import wrappers_pb2 from google.protobuf.internal import well_known_types from google.protobuf import json_format from google.protobuf.util import json_format_proto3_pb2 @@ -326,6 +332,7 @@ class JsonFormatTest(JsonFormatBase): message.bytes_value.value = b'' message.repeated_bool_value.add().value = True message.repeated_bool_value.add().value = False + message.repeated_int32_value.add() self.assertEqual( json.loads(json_format.MessageToJson(message, True)), json.loads('{\n' @@ -334,7 +341,7 @@ class JsonFormatTest(JsonFormatBase): ' "stringValue": "",' ' "bytesValue": "",' ' "repeatedBoolValue": [true, false],' - ' "repeatedInt32Value": [],' + ' "repeatedInt32Value": [0],' ' "repeatedUint32Value": [],' ' "repeatedFloatValue": [],' ' "repeatedDoubleValue": [],' @@ -346,11 +353,192 @@ class JsonFormatTest(JsonFormatBase): parsed_message = json_format_proto3_pb2.TestWrapper() self.CheckParseBack(message, parsed_message) + def testStructMessage(self): + message = json_format_proto3_pb2.TestStruct() + message.value['name'] = 'Jim' + message.value['age'] = 10 + message.value['attend'] = True + message.value['email'] = None + message.value.get_or_create_struct('address')['city'] = 'SFO' + message.value['address']['house_number'] = 1024 + struct_list = message.value.get_or_create_list('list') + struct_list.extend([6, 'seven', True, False, None]) + struct_list.add_struct()['subkey2'] = 9 + message.repeated_value.add()['age'] = 11 + message.repeated_value.add() + self.assertEqual( + json.loads(json_format.MessageToJson(message, False)), + json.loads( + '{' + ' "value": {' + ' "address": {' + ' "city": "SFO", ' + ' "house_number": 1024' + ' }, ' + ' "age": 10, ' + ' "name": "Jim", ' + ' "attend": true, ' + ' "email": null, ' + ' "list": [6, "seven", true, false, null, {"subkey2": 9}]' + ' },' + ' "repeatedValue": [{"age": 11}, {}]' + '}')) + parsed_message = json_format_proto3_pb2.TestStruct() + self.CheckParseBack(message, parsed_message) + + def testValueMessage(self): + message = json_format_proto3_pb2.TestValue() + message.value.string_value = 'hello' + message.repeated_value.add().number_value = 11.1 + message.repeated_value.add().bool_value = False + message.repeated_value.add().null_value = 0 + self.assertEqual( + json.loads(json_format.MessageToJson(message, False)), + json.loads( + '{' + ' "value": "hello",' + ' "repeatedValue": [11.1, false, null]' + '}')) + parsed_message = json_format_proto3_pb2.TestValue() + self.CheckParseBack(message, parsed_message) + # Can't parse back if the Value message is not set. + message.repeated_value.add() + self.assertEqual( + json.loads(json_format.MessageToJson(message, False)), + json.loads( + '{' + ' "value": "hello",' + ' "repeatedValue": [11.1, false, null, null]' + '}')) + + def testListValueMessage(self): + message = json_format_proto3_pb2.TestListValue() + message.value.values.add().number_value = 11.1 + message.value.values.add().null_value = 0 + message.value.values.add().bool_value = True + message.value.values.add().string_value = 'hello' + message.value.values.add().struct_value['name'] = 'Jim' + message.repeated_value.add().values.add().number_value = 1 + message.repeated_value.add() + self.assertEqual( + json.loads(json_format.MessageToJson(message, False)), + json.loads( + '{"value": [11.1, null, true, "hello", {"name": "Jim"}]\n,' + '"repeatedValue": [[1], []]}')) + parsed_message = json_format_proto3_pb2.TestListValue() + self.CheckParseBack(message, parsed_message) + + def testAnyMessage(self): + message = json_format_proto3_pb2.TestAny() + value1 = json_format_proto3_pb2.MessageType() + value2 = json_format_proto3_pb2.MessageType() + value1.value = 1234 + value2.value = 5678 + message.value.Pack(value1) + message.repeated_value.add().Pack(value1) + message.repeated_value.add().Pack(value2) + message.repeated_value.add() + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads( + '{\n' + ' "repeatedValue": [ {\n' + ' "@type": "type.googleapis.com/proto3.MessageType",\n' + ' "value": 1234\n' + ' }, {\n' + ' "@type": "type.googleapis.com/proto3.MessageType",\n' + ' "value": 5678\n' + ' },\n' + ' {}],\n' + ' "value": {\n' + ' "@type": "type.googleapis.com/proto3.MessageType",\n' + ' "value": 1234\n' + ' }\n' + '}\n')) + parsed_message = json_format_proto3_pb2.TestAny() + self.CheckParseBack(message, parsed_message) + + def testWellKnownInAnyMessage(self): + message = any_pb2.Any() + int32_value = wrappers_pb2.Int32Value() + int32_value.value = 1234 + message.Pack(int32_value) + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads( + '{\n' + ' "@type": \"type.googleapis.com/google.protobuf.Int32Value\",\n' + ' "value": 1234\n' + '}\n')) + parsed_message = any_pb2.Any() + self.CheckParseBack(message, parsed_message) + + timestamp = timestamp_pb2.Timestamp() + message.Pack(timestamp) + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads( + '{\n' + ' "@type": "type.googleapis.com/google.protobuf.Timestamp",\n' + ' "value": "1970-01-01T00:00:00Z"\n' + '}\n')) + self.CheckParseBack(message, parsed_message) + + duration = duration_pb2.Duration() + duration.seconds = 1 + message.Pack(duration) + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads( + '{\n' + ' "@type": "type.googleapis.com/google.protobuf.Duration",\n' + ' "value": "1s"\n' + '}\n')) + self.CheckParseBack(message, parsed_message) + + field_mask = field_mask_pb2.FieldMask() + field_mask.paths.append('foo.bar') + field_mask.paths.append('bar') + message.Pack(field_mask) + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads( + '{\n' + ' "@type": "type.googleapis.com/google.protobuf.FieldMask",\n' + ' "value": "foo.bar,bar"\n' + '}\n')) + self.CheckParseBack(message, parsed_message) + + struct_message = struct_pb2.Struct() + struct_message['name'] = 'Jim' + message.Pack(struct_message) + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads( + '{\n' + ' "@type": "type.googleapis.com/google.protobuf.Struct",\n' + ' "value": {"name": "Jim"}\n' + '}\n')) + self.CheckParseBack(message, parsed_message) + + nested_any = any_pb2.Any() + int32_value.value = 5678 + nested_any.Pack(int32_value) + message.Pack(nested_any) + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads( + '{\n' + ' "@type": "type.googleapis.com/google.protobuf.Any",\n' + ' "value": {\n' + ' "@type": "type.googleapis.com/google.protobuf.Int32Value",\n' + ' "value": 5678\n' + ' }\n' + '}\n')) + self.CheckParseBack(message, parsed_message) + def testParseNull(self): message = json_format_proto3_pb2.TestMessage() - message.repeated_int32_value.append(1) - message.repeated_int32_value.append(2) - message.repeated_int32_value.append(3) parsed_message = json_format_proto3_pb2.TestMessage() self.FillAllFields(parsed_message) json_format.Parse('{"int32Value": null, ' @@ -364,7 +552,7 @@ class JsonFormatTest(JsonFormatBase): '"bytesValue": null,' '"messageValue": null,' '"enumValue": null,' - '"repeatedInt32Value": [1, 2, null, 3],' + '"repeatedInt32Value": null,' '"repeatedInt64Value": null,' '"repeatedUint32Value": null,' '"repeatedUint64Value": null,' @@ -378,6 +566,13 @@ class JsonFormatTest(JsonFormatBase): '}', parsed_message) self.assertEqual(message, parsed_message) + self.assertRaisesRegexp( + json_format.ParseError, + 'Failed to parse repeatedInt32Value field: ' + 'null is not allowed to be used as an element in a repeated field.', + json_format.Parse, + '{"repeatedInt32Value":[1, null]}', + parsed_message) def testNanFloat(self): message = json_format_proto3_pb2.TestMessage() @@ -529,6 +724,45 @@ class JsonFormatTest(JsonFormatBase): ' should not have multiple "oneof_value" oneof fields.', json_format.Parse, text, message) + def testInvalidListValue(self): + message = json_format_proto3_pb2.TestListValue() + text = '{"value": 1234}' + self.assertRaisesRegexp( + json_format.ParseError, + r'Failed to parse value field: ListValue must be in \[\] which is 1234', + json_format.Parse, text, message) + + def testInvalidStruct(self): + message = json_format_proto3_pb2.TestStruct() + text = '{"value": 1234}' + self.assertRaisesRegexp( + json_format.ParseError, + 'Failed to parse value field: Struct must be in a dict which is 1234', + json_format.Parse, text, message) + + def testInvalidAny(self): + message = any_pb2.Any() + text = '{"@type": "type.googleapis.com/google.protobuf.Int32Value"}' + self.assertRaisesRegexp( + KeyError, + 'value', + json_format.Parse, text, message) + text = '{"value": 1234}' + self.assertRaisesRegexp( + json_format.ParseError, + '@type is missing when parsing any message.', + json_format.Parse, text, message) + text = '{"@type": "type.googleapis.com/MessageNotExist", "value": 1234}' + self.assertRaisesRegexp( + TypeError, + 'Can not find message descriptor by type_url: ' + 'type.googleapis.com/MessageNotExist.', + json_format.Parse, text, message) + # Only last part is to be used. + text = (r'{"@type": "incorrect.googleapis.com/google.protobuf.Int32Value",' + r'"value": 1234}') + json_format.Parse(text, message) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py index d3de9831..d35fcc5f 100644 --- a/python/google/protobuf/internal/well_known_types.py +++ b/python/google/protobuf/internal/well_known_types.py @@ -34,6 +34,7 @@ This files defines well known classes which need extra maintenance including: - Any - Duration - FieldMask + - Struct - Timestamp """ @@ -41,6 +42,7 @@ __author__ = 'jieluo@google.com (Jie Luo)' from datetime import datetime from datetime import timedelta +import six from google.protobuf.descriptor import FieldDescriptor @@ -64,9 +66,12 @@ class ParseError(Error): class Any(object): """Class for Any Message type.""" - def Pack(self, msg): + def Pack(self, msg, type_url_prefix='type.googleapis.com/'): """Packs the specified message into current Any message.""" - self.type_url = 'type.googleapis.com/%s' % msg.DESCRIPTOR.full_name + if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/': + self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name) + else: + self.type_url = '%s%s' % (type_url_prefix, msg.DESCRIPTOR.full_name) self.value = msg.SerializeToString() def Unpack(self, msg): @@ -614,9 +619,102 @@ def _AddFieldPaths(node, prefix, field_mask): _AddFieldPaths(node[name], child_path, field_mask) +_INT_OR_FLOAT = six.integer_types + (float,) + + +def _SetStructValue(struct_value, value): + if value is None: + struct_value.null_value = 0 + elif isinstance(value, bool): + # Note: this check must come before the number check because in Python + # True and False are also considered numbers. + struct_value.bool_value = value + elif isinstance(value, six.string_types): + struct_value.string_value = value + elif isinstance(value, _INT_OR_FLOAT): + struct_value.number_value = value + else: + raise ValueError('Unexpected type') + + +def _GetStructValue(struct_value): + which = struct_value.WhichOneof('kind') + if which == 'struct_value': + return struct_value.struct_value + elif which == 'null_value': + return None + elif which == 'number_value': + return struct_value.number_value + elif which == 'string_value': + return struct_value.string_value + elif which == 'bool_value': + return struct_value.bool_value + elif which == 'list_value': + return struct_value.list_value + elif which is None: + raise ValueError('Value not set') + + +class Struct(object): + """Class for Struct message type.""" + + __slots__ = [] + + def __getitem__(self, key): + return _GetStructValue(self.fields[key]) + + def __setitem__(self, key, value): + _SetStructValue(self.fields[key], value) + + def get_or_create_list(self, key): + """Returns a list for this key, creating if it didn't exist already.""" + return self.fields[key].list_value + + def get_or_create_struct(self, key): + """Returns a struct for this key, creating if it didn't exist already.""" + return self.fields[key].struct_value + + # TODO(haberman): allow constructing/merging from dict. + + +class ListValue(object): + """Class for ListValue message type.""" + + def __len__(self): + return len(self.values) + + def append(self, value): + _SetStructValue(self.values.add(), value) + + def extend(self, elem_seq): + for value in elem_seq: + self.append(value) + + def __getitem__(self, index): + """Retrieves item by the specified index.""" + return _GetStructValue(self.values.__getitem__(index)) + + def __setitem__(self, index, value): + _SetStructValue(self.values.__getitem__(index), value) + + def items(self): + for i in range(len(self)): + yield self[i] + + def add_struct(self): + """Appends and returns a struct value as the next value in the list.""" + return self.values.add().struct_value + + def add_list(self): + """Appends and returns a list value as the next value in the list.""" + return self.values.add().list_value + + WKTBASES = { 'google.protobuf.Any': Any, 'google.protobuf.Duration': Duration, 'google.protobuf.FieldMask': FieldMask, + 'google.protobuf.ListValue': ListValue, + 'google.protobuf.Struct': Struct, 'google.protobuf.Timestamp': Timestamp, } diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py index 0e31e6f8..6acbee22 100644 --- a/python/google/protobuf/internal/well_known_types_test.py +++ b/python/google/protobuf/internal/well_known_types_test.py @@ -41,13 +41,17 @@ try: except ImportError: import unittest +from google.protobuf import any_pb2 from google.protobuf import duration_pb2 from google.protobuf import field_mask_pb2 +from google.protobuf import struct_pb2 from google.protobuf import timestamp_pb2 from google.protobuf import unittest_pb2 +from google.protobuf.internal import any_test_pb2 from google.protobuf.internal import test_util from google.protobuf.internal import well_known_types from google.protobuf import descriptor +from google.protobuf import text_format class TimeUtilTestBase(unittest.TestCase): @@ -509,5 +513,124 @@ class FieldMaskTest(unittest.TestCase): self.assertEqual(1, len(nested_dst.payload.repeated_int32)) self.assertEqual(1234, nested_dst.payload.repeated_int32[0]) + +class StructTest(unittest.TestCase): + + def testStruct(self): + struct = struct_pb2.Struct() + struct_class = struct.__class__ + + struct['key1'] = 5 + struct['key2'] = 'abc' + struct['key3'] = True + struct.get_or_create_struct('key4')['subkey'] = 11.0 + struct_list = struct.get_or_create_list('key5') + struct_list.extend([6, 'seven', True, False, None]) + struct_list.add_struct()['subkey2'] = 9 + + self.assertTrue(isinstance(struct, well_known_types.Struct)) + self.assertEquals(5, struct['key1']) + self.assertEquals('abc', struct['key2']) + self.assertIs(True, struct['key3']) + self.assertEquals(11, struct['key4']['subkey']) + inner_struct = struct_class() + inner_struct['subkey2'] = 9 + self.assertEquals([6, 'seven', True, False, None, inner_struct], + list(struct['key5'].items())) + + serialized = struct.SerializeToString() + + struct2 = struct_pb2.Struct() + struct2.ParseFromString(serialized) + + self.assertEquals(struct, struct2) + + self.assertTrue(isinstance(struct2, well_known_types.Struct)) + self.assertEquals(5, struct2['key1']) + self.assertEquals('abc', struct2['key2']) + self.assertIs(True, struct2['key3']) + self.assertEquals(11, struct2['key4']['subkey']) + self.assertEquals([6, 'seven', True, False, None, inner_struct], + list(struct2['key5'].items())) + + struct_list = struct2['key5'] + self.assertEquals(6, struct_list[0]) + self.assertEquals('seven', struct_list[1]) + self.assertEquals(True, struct_list[2]) + self.assertEquals(False, struct_list[3]) + self.assertEquals(None, struct_list[4]) + self.assertEquals(inner_struct, struct_list[5]) + + struct_list[1] = 7 + self.assertEquals(7, struct_list[1]) + + struct_list.add_list().extend([1, 'two', True, False, None]) + self.assertEquals([1, 'two', True, False, None], + list(struct_list[6].items())) + + text_serialized = str(struct) + struct3 = struct_pb2.Struct() + text_format.Merge(text_serialized, struct3) + self.assertEquals(struct, struct3) + + struct.get_or_create_struct('key3')['replace'] = 12 + self.assertEquals(12, struct['key3']['replace']) + + +class AnyTest(unittest.TestCase): + + def testAnyMessage(self): + # Creates and sets message. + msg = any_test_pb2.TestAny() + msg_descriptor = msg.DESCRIPTOR + all_types = unittest_pb2.TestAllTypes() + all_descriptor = all_types.DESCRIPTOR + all_types.repeated_string.append(u'\u00fc\ua71f') + # Packs to Any. + msg.value.Pack(all_types) + self.assertEqual(msg.value.type_url, + 'type.googleapis.com/%s' % all_descriptor.full_name) + self.assertEqual(msg.value.value, + all_types.SerializeToString()) + # Tests Is() method. + self.assertTrue(msg.value.Is(all_descriptor)) + self.assertFalse(msg.value.Is(msg_descriptor)) + # Unpacks Any. + unpacked_message = unittest_pb2.TestAllTypes() + self.assertTrue(msg.value.Unpack(unpacked_message)) + self.assertEqual(all_types, unpacked_message) + # Unpacks to different type. + self.assertFalse(msg.value.Unpack(msg)) + # Only Any messages have Pack method. + try: + msg.Pack(all_types) + except AttributeError: + pass + else: + raise AttributeError('%s should not have Pack method.' % + msg_descriptor.full_name) + + def testPackWithCustomTypeUrl(self): + submessage = any_test_pb2.TestAny() + submessage.int_value = 12345 + msg = any_pb2.Any() + # Pack with a custom type URL prefix. + msg.Pack(submessage, 'type.myservice.com') + self.assertEqual(msg.type_url, + 'type.myservice.com/%s' % submessage.DESCRIPTOR.full_name) + # Pack with a custom type URL prefix ending with '/'. + msg.Pack(submessage, 'type.myservice.com/') + self.assertEqual(msg.type_url, + 'type.myservice.com/%s' % submessage.DESCRIPTOR.full_name) + # Pack with an empty type URL prefix. + msg.Pack(submessage, '') + self.assertEqual(msg.type_url, + '/%s' % submessage.DESCRIPTOR.full_name) + # Test unpacking the type. + unpacked_message = any_test_pb2.TestAny() + self.assertTrue(msg.Unpack(unpacked_message)) + self.assertEqual(submessage, unpacked_message) + + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py index cb76e116..23382bdb 100644 --- a/python/google/protobuf/json_format.py +++ b/python/google/protobuf/json_format.py @@ -45,10 +45,11 @@ __author__ = 'jieluo@google.com (Jie Luo)' import base64 import json import math -from six import text_type +import six import sys from google.protobuf import descriptor +from google.protobuf import symbol_database _TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S' _INT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT32, @@ -96,11 +97,15 @@ def MessageToJson(message, including_default_value_fields=False): def _MessageToJsonObject(message, including_default_value_fields): """Converts message to an object according to Proto3 JSON Specification.""" message_descriptor = message.DESCRIPTOR - if hasattr(message, 'ToJsonString'): - return message.ToJsonString() + full_name = message_descriptor.full_name if _IsWrapperMessage(message_descriptor): return _WrapperMessageToJsonObject(message) - return _RegularMessageToJsonObject(message, including_default_value_fields) + if full_name in _WKTJSONMETHODS: + return _WKTJSONMETHODS[full_name][0]( + message, including_default_value_fields) + js = {} + return _RegularMessageToJsonObject( + message, js, including_default_value_fields) def _IsMapEntry(field): @@ -109,9 +114,8 @@ def _IsMapEntry(field): field.message_type.GetOptions().map_entry) -def _RegularMessageToJsonObject(message, including_default_value_fields): +def _RegularMessageToJsonObject(message, js, including_default_value_fields): """Converts normal message according to Proto3 JSON Specification.""" - js = {} fields = message.ListFields() include_default = including_default_value_fields @@ -200,6 +204,79 @@ def _FieldToJsonObject( return value +def _AnyMessageToJsonObject(message, including_default): + """Converts Any message according to Proto3 JSON Specification.""" + if not message.ListFields(): + return {} + js = {} + type_url = message.type_url + js['@type'] = type_url + sub_message = _CreateMessageFromTypeUrl(type_url) + sub_message.ParseFromString(message.value) + message_descriptor = sub_message.DESCRIPTOR + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): + js['value'] = _WrapperMessageToJsonObject(sub_message) + return js + if full_name in _WKTJSONMETHODS: + js['value'] = _WKTJSONMETHODS[full_name][0](sub_message, including_default) + return js + return _RegularMessageToJsonObject(sub_message, js, including_default) + + +def _CreateMessageFromTypeUrl(type_url): + # TODO(jieluo): Should add a way that users can register the type resolver + # instead of the default one. + db = symbol_database.Default() + type_name = type_url.split('/')[-1] + try: + message_descriptor = db.pool.FindMessageTypeByName(type_name) + except KeyError: + raise TypeError( + 'Can not find message descriptor by type_url: {0}.'.format(type_url)) + message_class = db.GetPrototype(message_descriptor) + return message_class() + + +def _GenericMessageToJsonObject(message, unused_including_default): + """Converts message by ToJsonString according to Proto3 JSON Specification.""" + # Duration, Timestamp and FieldMask have ToJsonString method to do the + # convert. Users can also call the method directly. + return message.ToJsonString() + + +def _ValueMessageToJsonObject(message, unused_including_default=False): + """Converts Value message according to Proto3 JSON Specification.""" + which = message.WhichOneof('kind') + # If the Value message is not set treat as null_value when serialize + # to JSON. The parse back result will be different from original message. + if which is None or which == 'null_value': + return None + if which == 'list_value': + return _ListValueMessageToJsonObject(message.list_value) + if which == 'struct_value': + value = message.struct_value + else: + value = getattr(message, which) + oneof_descriptor = message.DESCRIPTOR.fields_by_name[which] + return _FieldToJsonObject(oneof_descriptor, value) + + +def _ListValueMessageToJsonObject(message, unused_including_default=False): + """Converts ListValue message according to Proto3 JSON Specification.""" + return [_ValueMessageToJsonObject(value) + for value in message.values] + + +def _StructMessageToJsonObject(message, unused_including_default=False): + """Converts Struct message according to Proto3 JSON Specification.""" + fields = message.fields + js = {} + for key in fields.keys(): + js[key] = _ValueMessageToJsonObject(fields[key]) + return js + + def _IsWrapperMessage(message_descriptor): return message_descriptor.file.name == 'google/protobuf/wrappers.proto' @@ -231,7 +308,7 @@ def Parse(text, message): Raises:: ParseError: On JSON parsing problems. """ - if not isinstance(text, text_type): text = text.decode('utf-8') + if not isinstance(text, six.text_type): text = text.decode('utf-8') try: if sys.version_info < (2, 7): # object_pair_hook is not supported before python2.7 @@ -240,7 +317,7 @@ def Parse(text, message): js = json.loads(text, object_pairs_hook=_DuplicateChecker) except ValueError as e: raise ParseError('Failed to load JSON: {0}.'.format(str(e))) - _ConvertFieldValuePair(js, message) + _ConvertMessage(js, message) return message @@ -291,13 +368,22 @@ def _ConvertFieldValuePair(js, message): if not isinstance(value, list): raise ParseError('repeated field {0} must be in [] which is ' '{1}.'.format(name, value)) - for item in value: - if item is None: - continue - if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + # Repeated message field. + for item in value: sub_message = getattr(message, field.name).add() + # None is a null_value in Value. + if (item is None and + sub_message.DESCRIPTOR.full_name != 'google.protobuf.Value'): + raise ParseError('null is not allowed to be used as an element' + ' in a repeated field.') _ConvertMessage(item, sub_message) - else: + else: + # Repeated scalar field. + for item in value: + if item is None: + raise ParseError('null is not allowed to be used as an element' + ' in a repeated field.') getattr(message, field.name).append( _ConvertScalarFieldValue(item, field)) elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: @@ -327,13 +413,87 @@ def _ConvertMessage(value, message): ParseError: In case of convert problems. """ message_descriptor = message.DESCRIPTOR - if hasattr(message, 'FromJsonString'): - message.FromJsonString(value) - elif _IsWrapperMessage(message_descriptor): + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): _ConvertWrapperMessage(value, message) + elif full_name in _WKTJSONMETHODS: + _WKTJSONMETHODS[full_name][1](value, message) else: _ConvertFieldValuePair(value, message) + +def _ConvertAnyMessage(value, message): + """Convert a JSON representation into Any message.""" + if isinstance(value, dict) and not value: + return + try: + type_url = value['@type'] + except KeyError: + raise ParseError('@type is missing when parsing any message.') + + sub_message = _CreateMessageFromTypeUrl(type_url) + message_descriptor = sub_message.DESCRIPTOR + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): + _ConvertWrapperMessage(value['value'], sub_message) + elif full_name in _WKTJSONMETHODS: + _WKTJSONMETHODS[full_name][1](value['value'], sub_message) + else: + del value['@type'] + _ConvertFieldValuePair(value, sub_message) + # Sets Any message + message.value = sub_message.SerializeToString() + message.type_url = type_url + + +def _ConvertGenericMessage(value, message): + """Convert a JSON representation into message with FromJsonString.""" + # Durantion, Timestamp, FieldMask have FromJsonString method to do the + # convert. Users can also call the method directly. + message.FromJsonString(value) + + +_INT_OR_FLOAT = six.integer_types + (float,) + + +def _ConvertValueMessage(value, message): + """Convert a JSON representation into Value message.""" + if isinstance(value, dict): + _ConvertStructMessage(value, message.struct_value) + elif isinstance(value, list): + _ConvertListValueMessage(value, message.list_value) + elif value is None: + message.null_value = 0 + elif isinstance(value, bool): + message.bool_value = value + elif isinstance(value, six.string_types): + message.string_value = value + elif isinstance(value, _INT_OR_FLOAT): + message.number_value = value + else: + raise ParseError('Unexpected type for Value message.') + + +def _ConvertListValueMessage(value, message): + """Convert a JSON representation into ListValue message.""" + if not isinstance(value, list): + raise ParseError( + 'ListValue must be in [] which is {0}.'.format(value)) + message.ClearField('values') + for item in value: + _ConvertValueMessage(item, message.values.add()) + + +def _ConvertStructMessage(value, message): + """Convert a JSON representation into Struct message.""" + if not isinstance(value, dict): + raise ParseError( + 'Struct must be in a dict which is {0}.'.format(value)) + for key in value: + _ConvertValueMessage(value[key], message.fields[key]) + return + + def _ConvertWrapperMessage(value, message): """Convert a JSON representation into Wrapper message.""" field = message.DESCRIPTOR.fields_by_name['value'] @@ -353,7 +513,8 @@ def _ConvertMapFieldValue(value, message, field): """ if not isinstance(value, dict): raise ParseError( - 'Map fieled {0} must be in {} which is {1}.'.format(field.name, value)) + 'Map field {0} must be in a dict which is {1}.'.format( + field.name, value)) key_field = field.message_type.fields_by_name['key'] value_field = field.message_type.fields_by_name['value'] for key in value: @@ -416,7 +577,7 @@ def _ConvertInteger(value): if isinstance(value, float): raise ParseError('Couldn\'t parse integer: {0}.'.format(value)) - if isinstance(value, text_type) and value.find(' ') != -1: + if isinstance(value, six.text_type) and value.find(' ') != -1: raise ParseError('Couldn\'t parse integer: "{0}".'.format(value)) return int(value) @@ -465,3 +626,20 @@ def _ConvertBool(value, require_str): if not isinstance(value, bool): raise ParseError('Expected true or false without quotes.') return value + +_WKTJSONMETHODS = { + 'google.protobuf.Any': [_AnyMessageToJsonObject, + _ConvertAnyMessage], + 'google.protobuf.Duration': [_GenericMessageToJsonObject, + _ConvertGenericMessage], + 'google.protobuf.FieldMask': [_GenericMessageToJsonObject, + _ConvertGenericMessage], + 'google.protobuf.ListValue': [_ListValueMessageToJsonObject, + _ConvertListValueMessage], + 'google.protobuf.Struct': [_StructMessageToJsonObject, + _ConvertStructMessage], + 'google.protobuf.Timestamp': [_GenericMessageToJsonObject, + _ConvertGenericMessage], + 'google.protobuf.Value': [_ValueMessageToJsonObject, + _ConvertValueMessage] +} -- cgit v1.2.3