aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf
diff options
context:
space:
mode:
authorGravatar Feng Xiao <xfxyjwf@gmail.com>2016-07-13 13:47:51 -0700
committerGravatar Feng Xiao <xfxyjwf@gmail.com>2016-07-13 13:48:40 -0700
commit9086d9643903c608ab015b0b7d903547a4e7b6f3 (patch)
treeb47053ab6f6bde20b55c4fff4019c68a7c45545c /python/google/protobuf
parent70c1ac756d3cd8fa04725f82f0ad1a30404c3bb3 (diff)
Integrate from internal code base.
Diffstat (limited to 'python/google/protobuf')
-rw-r--r--python/google/protobuf/internal/json_format_test.py6
-rwxr-xr-xpython/google/protobuf/internal/python_message.py26
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py1
-rw-r--r--python/google/protobuf/internal/symbol_database_test.py26
-rw-r--r--python/google/protobuf/pyext/cpp_message.py6
-rw-r--r--python/google/protobuf/pyext/map_container.cc3
-rw-r--r--python/google/protobuf/pyext/message.cc16
-rw-r--r--python/google/protobuf/pyext/message.h4
-rw-r--r--python/google/protobuf/pyext/repeated_composite_container.cc4
-rwxr-xr-xpython/google/protobuf/reflection.py8
-rw-r--r--python/google/protobuf/symbol_database.py82
11 files changed, 82 insertions, 100 deletions
diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py
index 6df12bea..a5ee8ace 100644
--- a/python/google/protobuf/internal/json_format_test.py
+++ b/python/google/protobuf/internal/json_format_test.py
@@ -252,10 +252,7 @@ class JsonFormatTest(JsonFormatBase):
message = json_format_proto3_pb2.TestMessage()
json_format.Parse('{"stringValue": "\\uD83D\\uDE01"}', message)
self.assertEqual(message.string_value,
- b'\xF0\x9F\x98\x81'.decode("utf-8", "strict"))
-
- # TODO: add test that UTF-8 encoded surrogate code points are rejected.
- # UTF-8 does not allow them.
+ b'\xF0\x9F\x98\x81'.decode('utf-8', 'strict'))
# Error case: unpaired high surrogate.
self.CheckError(
@@ -267,7 +264,6 @@ class JsonFormatTest(JsonFormatBase):
'{"stringValue": "\\uDE01"}',
r'Invalid \\uXXXX escape|Unpaired.*surrogate')
-
def testTimestampMessage(self):
message = json_format_proto3_pb2.TestTimestamp()
message.value.seconds = 0
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index f8f73dd2..c0d0ad45 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -76,7 +76,6 @@ from google.protobuf.internal import well_known_types
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
-from google.protobuf import symbol_database
from google.protobuf import text_format
_FieldDescriptor = descriptor_mod.FieldDescriptor
@@ -98,16 +97,12 @@ class GeneratedProtocolMessageType(type):
classes at runtime, as in this example:
mydescriptor = Descriptor(.....)
- class MyProtoClass(Message):
- __metaclass__ = GeneratedProtocolMessageType
- DESCRIPTOR = mydescriptor
+ factory = symbol_database.Default()
+ factory.pool.AddDescriptor(mydescriptor)
+ MyProtoClass = factory.GetPrototype(mydescriptor)
myproto_instance = MyProtoClass()
myproto.foo_field = 23
...
-
- The above example will not work for nested types. If you wish to include them,
- use reflection.MakeClass() instead of manually instantiating the class in
- order to create the appropriate class structure.
"""
# Must be consistent with the protocol-compiler code in
@@ -926,26 +921,33 @@ def _InternalUnpackAny(msg):
Returns:
The unpacked message.
"""
+ # TODO(amauryfa): Don't use the factory of generated messages.
+ # To make Any work with custom factories, use the message factory of the
+ # parent message.
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf import symbol_database
+ factory = symbol_database.Default()
+
type_url = msg.type_url
- db = symbol_database.Default()
if not type_url:
return None
# TODO(haberman): For now we just strip the hostname. Better logic will be
# required.
- type_name = type_url.split("/")[-1]
- descriptor = db.pool.FindMessageTypeByName(type_name)
+ type_name = type_url.split('/')[-1]
+ descriptor = factory.pool.FindMessageTypeByName(type_name)
if descriptor is None:
return None
- message_class = db.GetPrototype(descriptor)
+ message_class = factory.GetPrototype(descriptor)
message = message_class()
message.ParseFromString(msg.value)
return message
+
def _AddEqualsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def __eq__(self, other):
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 6dc2fffe..20e5d245 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -972,6 +972,7 @@ class ReflectionTest(unittest.TestCase):
proto.repeated_nested_message.add(bb=23)
self.assertEqual(1, len(proto.repeated_nested_message))
self.assertEqual(23, proto.repeated_nested_message[0].bb)
+ self.assertRaises(TypeError, proto.repeated_nested_message.add, 23)
def testRepeatedCompositeRemove(self):
proto = unittest_pb2.TestAllTypes()
diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py
index c99b426d..4f5173b2 100644
--- a/python/google/protobuf/internal/symbol_database_test.py
+++ b/python/google/protobuf/internal/symbol_database_test.py
@@ -39,26 +39,28 @@ except ImportError:
from google.protobuf import unittest_pb2
from google.protobuf import descriptor
+from google.protobuf import descriptor_pool
from google.protobuf import symbol_database
+
class SymbolDatabaseTest(unittest.TestCase):
def _Database(self):
- # TODO(b/17734095): Remove this difference when the C++ implementation
- # supports multiple databases.
if descriptor._USE_C_DESCRIPTORS:
- return symbol_database.Default()
+ # The C++ implementation does not allow mixing descriptors from
+ # different pools.
+ db = symbol_database.SymbolDatabase(pool=descriptor_pool.Default())
else:
db = symbol_database.SymbolDatabase()
- # Register representative types from unittest_pb2.
- db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR)
- db.RegisterMessage(unittest_pb2.TestAllTypes)
- db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage)
- db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup)
- db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup)
- db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR)
- db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR)
- return db
+ # Register representative types from unittest_pb2.
+ db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR)
+ db.RegisterMessage(unittest_pb2.TestAllTypes)
+ db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage)
+ db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup)
+ db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup)
+ db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR)
+ db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR)
+ return db
def testGetPrototype(self):
instance = self._Database().GetPrototype(
diff --git a/python/google/protobuf/pyext/cpp_message.py b/python/google/protobuf/pyext/cpp_message.py
index b215211e..fc8eb32d 100644
--- a/python/google/protobuf/pyext/cpp_message.py
+++ b/python/google/protobuf/pyext/cpp_message.py
@@ -48,9 +48,9 @@ class GeneratedProtocolMessageType(_message.MessageMeta):
classes at runtime, as in this example:
mydescriptor = Descriptor(.....)
- class MyProtoClass(Message):
- __metaclass__ = GeneratedProtocolMessageType
- DESCRIPTOR = mydescriptor
+ factory = symbol_database.Default()
+ factory.pool.AddDescriptor(mydescriptor)
+ MyProtoClass = factory.GetPrototype(mydescriptor)
myproto_instance = MyProtoClass()
myproto.foo_field = 23
...
diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc
index 90438df1..0987b898 100644
--- a/python/google/protobuf/pyext/map_container.cc
+++ b/python/google/protobuf/pyext/map_container.cc
@@ -348,9 +348,10 @@ PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
}
// Initializes the underlying Message object of "to" so it becomes a new parent
-// repeated scalar, and copies all the values from "from" to it. A child scalar
+// map container, and copies all the values from "from" to it. A child map
// container can be released by passing it as both from and to (e.g. making it
// the recipient of the new parent message and copying the values from itself).
+// In fact, this is the only supported use at the moment.
static int InitializeAndCopyToParentContainer(MapContainer* from,
MapContainer* to) {
// For now we require from == to, re-evaluate if we want to support deep copy
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index a9261f20..5535338d 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -1041,7 +1041,12 @@ int InternalDeleteRepeatedField(
}
// Initializes fields of a message. Used in constructors.
-int InitAttributes(CMessage* self, PyObject* kwargs) {
+int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
+ if (args != NULL && PyTuple_Size(args) != 0) {
+ PyErr_SetString(PyExc_TypeError, "No positional arguments allowed");
+ return -1;
+ }
+
if (kwargs == NULL) {
return 0;
}
@@ -1167,7 +1172,7 @@ int InitAttributes(CMessage* self, PyObject* kwargs) {
}
CMessage* cmessage = reinterpret_cast<CMessage*>(message.get());
if (PyDict_Check(value)) {
- if (InitAttributes(cmessage, value) < 0) {
+ if (InitAttributes(cmessage, NULL, value) < 0) {
return -1;
}
} else {
@@ -1245,12 +1250,7 @@ static PyObject* New(PyTypeObject* cls,
// The __init__ method of Message classes.
// It initializes fields from keywords passed to the constructor.
static int Init(CMessage* self, PyObject* args, PyObject* kwargs) {
- if (PyTuple_Size(args) != 0) {
- PyErr_SetString(PyExc_TypeError, "No positional arguments allowed");
- return -1;
- }
-
- return InitAttributes(self, kwargs);
+ return InitAttributes(self, args, kwargs);
}
// ---------------------------------------------------------------------
diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h
index 8b399e05..c44a2ae2 100644
--- a/python/google/protobuf/pyext/message.h
+++ b/python/google/protobuf/pyext/message.h
@@ -237,7 +237,9 @@ PyObject* HasFieldByDescriptor(
PyObject* HasField(CMessage* self, PyObject* arg);
// Initializes values of fields on a newly constructed message.
-int InitAttributes(CMessage* self, PyObject* kwargs);
+// Note that positional arguments are disallowed: 'args' must be NULL or the
+// empty tuple.
+int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs);
PyObject* MergeFrom(CMessage* self, PyObject* arg);
diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc
index 4f339e77..bb2f6db2 100644
--- a/python/google/protobuf/pyext/repeated_composite_container.cc
+++ b/python/google/protobuf/pyext/repeated_composite_container.cc
@@ -146,7 +146,7 @@ static PyObject* AddToAttached(RepeatedCompositeContainer* self,
cmsg->owner = self->owner;
cmsg->message = sub_message;
cmsg->parent = self->parent;
- if (cmessage::InitAttributes(cmsg, kwargs) < 0) {
+ if (cmessage::InitAttributes(cmsg, args, kwargs) < 0) {
Py_DECREF(cmsg);
return NULL;
}
@@ -166,7 +166,7 @@ static PyObject* AddToReleased(RepeatedCompositeContainer* self,
// Create a new Message detached from the rest.
PyObject* py_cmsg = PyEval_CallObjectWithKeywords(
- self->child_message_class->AsPyObject(), NULL, kwargs);
+ self->child_message_class->AsPyObject(), args, kwargs);
if (py_cmsg == NULL)
return NULL;
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py
index 0c757264..51c83321 100755
--- a/python/google/protobuf/reflection.py
+++ b/python/google/protobuf/reflection.py
@@ -58,13 +58,7 @@ else:
from google.protobuf.internal import python_message as message_impl
# The type of all Message classes.
-# Part of the public interface.
-#
-# Used by generated files, but clients can also use it at runtime:
-# mydescriptor = pool.FindDescriptor(.....)
-# class MyProtoClass(Message):
-# __metaclass__ = GeneratedProtocolMessageType
-# DESCRIPTOR = mydescriptor
+# Part of the public interface, but normally only used by message factories.
GeneratedProtocolMessageType = message_impl.GeneratedProtocolMessageType
diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py
index 87760f26..aa466abd 100644
--- a/python/google/protobuf/symbol_database.py
+++ b/python/google/protobuf/symbol_database.py
@@ -30,11 +30,9 @@
"""A database of Python protocol buffer generated symbols.
-SymbolDatabase makes it easy to create new instances of a registered type, given
-only the type's protocol buffer symbol name. Once all symbols are registered,
-they can be accessed using either the MessageFactory interface which
-SymbolDatabase exposes, or the DescriptorPool interface of the underlying
-pool.
+SymbolDatabase is the MessageFactory for messages generated at compile time,
+and makes it easy to create new instances of a registered type, given only the
+type's protocol buffer symbol name.
Example usage:
@@ -61,27 +59,17 @@ Example usage:
from google.protobuf import descriptor_pool
+from google.protobuf import message_factory
-class SymbolDatabase(object):
- """A database of Python generated symbols.
-
- SymbolDatabase also models message_factory.MessageFactory.
-
- The symbol database can be used to keep a global registry of all protocol
- buffer types used within a program.
- """
-
- def __init__(self, pool=None):
- """Constructor."""
-
- self._symbols = {}
- self._symbols_by_file = {}
- self.pool = pool or descriptor_pool.Default()
+class SymbolDatabase(message_factory.MessageFactory):
+ """A database of Python generated symbols."""
def RegisterMessage(self, message):
"""Registers the given message type in the local database.
+ Calls to GetSymbol() and GetMessages() will return messages registered here.
+
Args:
message: a message.Message, to be registered.
@@ -90,10 +78,7 @@ class SymbolDatabase(object):
"""
desc = message.DESCRIPTOR
- self._symbols[desc.full_name] = message
- if desc.file.name not in self._symbols_by_file:
- self._symbols_by_file[desc.file.name] = {}
- self._symbols_by_file[desc.file.name][desc.full_name] = message
+ self._classes[desc.full_name] = message
self.pool.AddDescriptor(desc)
return message
@@ -136,47 +121,46 @@ class SymbolDatabase(object):
KeyError: if the symbol could not be found.
"""
- return self._symbols[symbol]
-
- def GetPrototype(self, descriptor):
- """Builds a proto2 message class based on the passed in descriptor.
-
- Passing a descriptor with a fully qualified name matching a previous
- invocation will cause the same class to be returned.
-
- Args:
- descriptor: The descriptor to build from.
-
- Returns:
- A class describing the passed in descriptor.
- """
-
- return self.GetSymbol(descriptor.full_name)
+ return self._classes[symbol]
def GetMessages(self, files):
- """Gets all the messages from a specified file.
-
- This will find and resolve dependencies, failing if they are not registered
- in the symbol database.
+ # TODO(amauryfa): Fix the differences with MessageFactory.
+ """Gets all registered messages from a specified file.
+ Only messages already created and registered will be returned; (this is the
+ case for imported _pb2 modules)
+ But unlike MessageFactory, this version also returns nested messages.
Args:
files: The file names to extract messages from.
Returns:
- A dictionary mapping proto names to the message classes. This will include
- any dependent messages as well as any messages defined in the same file as
- a specified message.
+ A dictionary mapping proto names to the message classes.
Raises:
KeyError: if a file could not be found.
"""
+ def _GetAllMessageNames(desc):
+ """Walk a message Descriptor and recursively yields all message names."""
+ yield desc.full_name
+ for msg_desc in desc.nested_types:
+ for full_name in _GetAllMessageNames(msg_desc):
+ yield full_name
+
result = {}
- for f in files:
- result.update(self._symbols_by_file[f])
+ for file_name in files:
+ file_desc = self.pool.FindFileByName(file_name)
+ for msg_desc in file_desc.message_types_by_name.values():
+ for full_name in _GetAllMessageNames(msg_desc):
+ try:
+ result[full_name] = self._classes[full_name]
+ except KeyError:
+ # This descriptor has no registered class, skip it.
+ pass
return result
+
_DEFAULT = SymbolDatabase(pool=descriptor_pool.Default())