diff options
author | Feng Xiao <xfxyjwf@gmail.com> | 2016-07-13 13:47:51 -0700 |
---|---|---|
committer | Feng Xiao <xfxyjwf@gmail.com> | 2016-07-13 13:48:40 -0700 |
commit | 9086d9643903c608ab015b0b7d903547a4e7b6f3 (patch) | |
tree | b47053ab6f6bde20b55c4fff4019c68a7c45545c /python/google | |
parent | 70c1ac756d3cd8fa04725f82f0ad1a30404c3bb3 (diff) |
Integrate from internal code base.
Diffstat (limited to 'python/google')
-rw-r--r-- | python/google/protobuf/internal/json_format_test.py | 6 | ||||
-rwxr-xr-x | python/google/protobuf/internal/python_message.py | 26 | ||||
-rwxr-xr-x | python/google/protobuf/internal/reflection_test.py | 1 | ||||
-rw-r--r-- | python/google/protobuf/internal/symbol_database_test.py | 26 | ||||
-rw-r--r-- | python/google/protobuf/pyext/cpp_message.py | 6 | ||||
-rw-r--r-- | python/google/protobuf/pyext/map_container.cc | 3 | ||||
-rw-r--r-- | python/google/protobuf/pyext/message.cc | 16 | ||||
-rw-r--r-- | python/google/protobuf/pyext/message.h | 4 | ||||
-rw-r--r-- | python/google/protobuf/pyext/repeated_composite_container.cc | 4 | ||||
-rwxr-xr-x | python/google/protobuf/reflection.py | 8 | ||||
-rw-r--r-- | python/google/protobuf/symbol_database.py | 82 |
11 files changed, 82 insertions, 100 deletions
diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py index 6df12bea..a5ee8ace 100644 --- a/python/google/protobuf/internal/json_format_test.py +++ b/python/google/protobuf/internal/json_format_test.py @@ -252,10 +252,7 @@ class JsonFormatTest(JsonFormatBase): message = json_format_proto3_pb2.TestMessage() json_format.Parse('{"stringValue": "\\uD83D\\uDE01"}', message) self.assertEqual(message.string_value, - b'\xF0\x9F\x98\x81'.decode("utf-8", "strict")) - - # TODO: add test that UTF-8 encoded surrogate code points are rejected. - # UTF-8 does not allow them. + b'\xF0\x9F\x98\x81'.decode('utf-8', 'strict')) # Error case: unpaired high surrogate. self.CheckError( @@ -267,7 +264,6 @@ class JsonFormatTest(JsonFormatBase): '{"stringValue": "\\uDE01"}', r'Invalid \\uXXXX escape|Unpaired.*surrogate') - def testTimestampMessage(self): message = json_format_proto3_pb2.TestTimestamp() message.value.seconds = 0 diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index f8f73dd2..c0d0ad45 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -76,7 +76,6 @@ from google.protobuf.internal import well_known_types from google.protobuf.internal import wire_format from google.protobuf import descriptor as descriptor_mod from google.protobuf import message as message_mod -from google.protobuf import symbol_database from google.protobuf import text_format _FieldDescriptor = descriptor_mod.FieldDescriptor @@ -98,16 +97,12 @@ class GeneratedProtocolMessageType(type): classes at runtime, as in this example: mydescriptor = Descriptor(.....) - class MyProtoClass(Message): - __metaclass__ = GeneratedProtocolMessageType - DESCRIPTOR = mydescriptor + factory = symbol_database.Default() + factory.pool.AddDescriptor(mydescriptor) + MyProtoClass = factory.GetPrototype(mydescriptor) myproto_instance = MyProtoClass() myproto.foo_field = 23 ... - - The above example will not work for nested types. If you wish to include them, - use reflection.MakeClass() instead of manually instantiating the class in - order to create the appropriate class structure. """ # Must be consistent with the protocol-compiler code in @@ -926,26 +921,33 @@ def _InternalUnpackAny(msg): Returns: The unpacked message. """ + # TODO(amauryfa): Don't use the factory of generated messages. + # To make Any work with custom factories, use the message factory of the + # parent message. + # pylint: disable=g-import-not-at-top + from google.protobuf import symbol_database + factory = symbol_database.Default() + type_url = msg.type_url - db = symbol_database.Default() if not type_url: return None # TODO(haberman): For now we just strip the hostname. Better logic will be # required. - type_name = type_url.split("/")[-1] - descriptor = db.pool.FindMessageTypeByName(type_name) + type_name = type_url.split('/')[-1] + descriptor = factory.pool.FindMessageTypeByName(type_name) if descriptor is None: return None - message_class = db.GetPrototype(descriptor) + message_class = factory.GetPrototype(descriptor) message = message_class() message.ParseFromString(msg.value) return message + def _AddEqualsMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def __eq__(self, other): diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 6dc2fffe..20e5d245 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -972,6 +972,7 @@ class ReflectionTest(unittest.TestCase): proto.repeated_nested_message.add(bb=23) self.assertEqual(1, len(proto.repeated_nested_message)) self.assertEqual(23, proto.repeated_nested_message[0].bb) + self.assertRaises(TypeError, proto.repeated_nested_message.add, 23) def testRepeatedCompositeRemove(self): proto = unittest_pb2.TestAllTypes() diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py index c99b426d..4f5173b2 100644 --- a/python/google/protobuf/internal/symbol_database_test.py +++ b/python/google/protobuf/internal/symbol_database_test.py @@ -39,26 +39,28 @@ except ImportError: from google.protobuf import unittest_pb2 from google.protobuf import descriptor +from google.protobuf import descriptor_pool from google.protobuf import symbol_database + class SymbolDatabaseTest(unittest.TestCase): def _Database(self): - # TODO(b/17734095): Remove this difference when the C++ implementation - # supports multiple databases. if descriptor._USE_C_DESCRIPTORS: - return symbol_database.Default() + # The C++ implementation does not allow mixing descriptors from + # different pools. + db = symbol_database.SymbolDatabase(pool=descriptor_pool.Default()) else: db = symbol_database.SymbolDatabase() - # Register representative types from unittest_pb2. - db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR) - db.RegisterMessage(unittest_pb2.TestAllTypes) - db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage) - db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup) - db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup) - db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) - db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) - return db + # Register representative types from unittest_pb2. + db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR) + db.RegisterMessage(unittest_pb2.TestAllTypes) + db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage) + db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup) + db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup) + db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) + db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) + return db def testGetPrototype(self): instance = self._Database().GetPrototype( diff --git a/python/google/protobuf/pyext/cpp_message.py b/python/google/protobuf/pyext/cpp_message.py index b215211e..fc8eb32d 100644 --- a/python/google/protobuf/pyext/cpp_message.py +++ b/python/google/protobuf/pyext/cpp_message.py @@ -48,9 +48,9 @@ class GeneratedProtocolMessageType(_message.MessageMeta): classes at runtime, as in this example: mydescriptor = Descriptor(.....) - class MyProtoClass(Message): - __metaclass__ = GeneratedProtocolMessageType - DESCRIPTOR = mydescriptor + factory = symbol_database.Default() + factory.pool.AddDescriptor(mydescriptor) + MyProtoClass = factory.GetPrototype(mydescriptor) myproto_instance = MyProtoClass() myproto.foo_field = 23 ... diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc index 90438df1..0987b898 100644 --- a/python/google/protobuf/pyext/map_container.cc +++ b/python/google/protobuf/pyext/map_container.cc @@ -348,9 +348,10 @@ PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) { } // Initializes the underlying Message object of "to" so it becomes a new parent -// repeated scalar, and copies all the values from "from" to it. A child scalar +// map container, and copies all the values from "from" to it. A child map // container can be released by passing it as both from and to (e.g. making it // the recipient of the new parent message and copying the values from itself). +// In fact, this is the only supported use at the moment. static int InitializeAndCopyToParentContainer(MapContainer* from, MapContainer* to) { // For now we require from == to, re-evaluate if we want to support deep copy diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index a9261f20..5535338d 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -1041,7 +1041,12 @@ int InternalDeleteRepeatedField( } // Initializes fields of a message. Used in constructors. -int InitAttributes(CMessage* self, PyObject* kwargs) { +int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { + if (args != NULL && PyTuple_Size(args) != 0) { + PyErr_SetString(PyExc_TypeError, "No positional arguments allowed"); + return -1; + } + if (kwargs == NULL) { return 0; } @@ -1167,7 +1172,7 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { } CMessage* cmessage = reinterpret_cast<CMessage*>(message.get()); if (PyDict_Check(value)) { - if (InitAttributes(cmessage, value) < 0) { + if (InitAttributes(cmessage, NULL, value) < 0) { return -1; } } else { @@ -1245,12 +1250,7 @@ static PyObject* New(PyTypeObject* cls, // The __init__ method of Message classes. // It initializes fields from keywords passed to the constructor. static int Init(CMessage* self, PyObject* args, PyObject* kwargs) { - if (PyTuple_Size(args) != 0) { - PyErr_SetString(PyExc_TypeError, "No positional arguments allowed"); - return -1; - } - - return InitAttributes(self, kwargs); + return InitAttributes(self, args, kwargs); } // --------------------------------------------------------------------- diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h index 8b399e05..c44a2ae2 100644 --- a/python/google/protobuf/pyext/message.h +++ b/python/google/protobuf/pyext/message.h @@ -237,7 +237,9 @@ PyObject* HasFieldByDescriptor( PyObject* HasField(CMessage* self, PyObject* arg); // Initializes values of fields on a newly constructed message. -int InitAttributes(CMessage* self, PyObject* kwargs); +// Note that positional arguments are disallowed: 'args' must be NULL or the +// empty tuple. +int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs); PyObject* MergeFrom(CMessage* self, PyObject* arg); diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc index 4f339e77..bb2f6db2 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.cc +++ b/python/google/protobuf/pyext/repeated_composite_container.cc @@ -146,7 +146,7 @@ static PyObject* AddToAttached(RepeatedCompositeContainer* self, cmsg->owner = self->owner; cmsg->message = sub_message; cmsg->parent = self->parent; - if (cmessage::InitAttributes(cmsg, kwargs) < 0) { + if (cmessage::InitAttributes(cmsg, args, kwargs) < 0) { Py_DECREF(cmsg); return NULL; } @@ -166,7 +166,7 @@ static PyObject* AddToReleased(RepeatedCompositeContainer* self, // Create a new Message detached from the rest. PyObject* py_cmsg = PyEval_CallObjectWithKeywords( - self->child_message_class->AsPyObject(), NULL, kwargs); + self->child_message_class->AsPyObject(), args, kwargs); if (py_cmsg == NULL) return NULL; diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index 0c757264..51c83321 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -58,13 +58,7 @@ else: from google.protobuf.internal import python_message as message_impl # The type of all Message classes. -# Part of the public interface. -# -# Used by generated files, but clients can also use it at runtime: -# mydescriptor = pool.FindDescriptor(.....) -# class MyProtoClass(Message): -# __metaclass__ = GeneratedProtocolMessageType -# DESCRIPTOR = mydescriptor +# Part of the public interface, but normally only used by message factories. GeneratedProtocolMessageType = message_impl.GeneratedProtocolMessageType diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py index 87760f26..aa466abd 100644 --- a/python/google/protobuf/symbol_database.py +++ b/python/google/protobuf/symbol_database.py @@ -30,11 +30,9 @@ """A database of Python protocol buffer generated symbols. -SymbolDatabase makes it easy to create new instances of a registered type, given -only the type's protocol buffer symbol name. Once all symbols are registered, -they can be accessed using either the MessageFactory interface which -SymbolDatabase exposes, or the DescriptorPool interface of the underlying -pool. +SymbolDatabase is the MessageFactory for messages generated at compile time, +and makes it easy to create new instances of a registered type, given only the +type's protocol buffer symbol name. Example usage: @@ -61,27 +59,17 @@ Example usage: from google.protobuf import descriptor_pool +from google.protobuf import message_factory -class SymbolDatabase(object): - """A database of Python generated symbols. - - SymbolDatabase also models message_factory.MessageFactory. - - The symbol database can be used to keep a global registry of all protocol - buffer types used within a program. - """ - - def __init__(self, pool=None): - """Constructor.""" - - self._symbols = {} - self._symbols_by_file = {} - self.pool = pool or descriptor_pool.Default() +class SymbolDatabase(message_factory.MessageFactory): + """A database of Python generated symbols.""" def RegisterMessage(self, message): """Registers the given message type in the local database. + Calls to GetSymbol() and GetMessages() will return messages registered here. + Args: message: a message.Message, to be registered. @@ -90,10 +78,7 @@ class SymbolDatabase(object): """ desc = message.DESCRIPTOR - self._symbols[desc.full_name] = message - if desc.file.name not in self._symbols_by_file: - self._symbols_by_file[desc.file.name] = {} - self._symbols_by_file[desc.file.name][desc.full_name] = message + self._classes[desc.full_name] = message self.pool.AddDescriptor(desc) return message @@ -136,47 +121,46 @@ class SymbolDatabase(object): KeyError: if the symbol could not be found. """ - return self._symbols[symbol] - - def GetPrototype(self, descriptor): - """Builds a proto2 message class based on the passed in descriptor. - - Passing a descriptor with a fully qualified name matching a previous - invocation will cause the same class to be returned. - - Args: - descriptor: The descriptor to build from. - - Returns: - A class describing the passed in descriptor. - """ - - return self.GetSymbol(descriptor.full_name) + return self._classes[symbol] def GetMessages(self, files): - """Gets all the messages from a specified file. - - This will find and resolve dependencies, failing if they are not registered - in the symbol database. + # TODO(amauryfa): Fix the differences with MessageFactory. + """Gets all registered messages from a specified file. + Only messages already created and registered will be returned; (this is the + case for imported _pb2 modules) + But unlike MessageFactory, this version also returns nested messages. Args: files: The file names to extract messages from. Returns: - A dictionary mapping proto names to the message classes. This will include - any dependent messages as well as any messages defined in the same file as - a specified message. + A dictionary mapping proto names to the message classes. Raises: KeyError: if a file could not be found. """ + def _GetAllMessageNames(desc): + """Walk a message Descriptor and recursively yields all message names.""" + yield desc.full_name + for msg_desc in desc.nested_types: + for full_name in _GetAllMessageNames(msg_desc): + yield full_name + result = {} - for f in files: - result.update(self._symbols_by_file[f]) + for file_name in files: + file_desc = self.pool.FindFileByName(file_name) + for msg_desc in file_desc.message_types_by_name.values(): + for full_name in _GetAllMessageNames(msg_desc): + try: + result[full_name] = self._classes[full_name] + except KeyError: + # This descriptor has no registered class, skip it. + pass return result + _DEFAULT = SymbolDatabase(pool=descriptor_pool.Default()) |