aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/pyext/message.cc
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/pyext/message.cc')
-rw-r--r--python/google/protobuf/pyext/message.cc153
1 files changed, 92 insertions, 61 deletions
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index 62c7c478..63d53136 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -55,6 +55,7 @@
#include <google/protobuf/descriptor.h>
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
+#include <google/protobuf/unknown_field_set.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/extension_dict.h>
@@ -107,8 +108,18 @@ struct PyMessageMeta {
// C++ descriptor of this message.
const Descriptor* message_descriptor;
+
// Owned reference, used to keep the pointer above alive.
PyObject* py_message_descriptor;
+
+ // The Python DescriptorPool used to create the class. It is needed to resolve
+ // fields descriptors, including extensions fields; its C++ MessageFactory is
+ // used to instantiate submessages.
+ // This can be different from DESCRIPTOR.file.pool, in the case of a custom
+ // DescriptorPool which defines new extensions.
+ // We own the reference, because it's important to keep the descriptors and
+ // factory alive.
+ PyDescriptorPool* py_descriptor_pool;
};
namespace message_meta {
@@ -139,18 +150,10 @@ static bool AddFieldNumberToClass(
// Finalize the creation of the Message class.
-// Called from its metaclass: GeneratedProtocolMessageType.__init__().
-static int AddDescriptors(PyObject* cls, PyObject* descriptor) {
- const Descriptor* message_descriptor =
- cdescriptor_pool::RegisterMessageClass(
- GetDescriptorPool(), cls, descriptor);
- if (message_descriptor == NULL) {
- return -1;
- }
-
+static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
// If there are extension_ranges, the message is "extendable", and extension
// classes will register themselves in this class.
- if (message_descriptor->extension_range_count() > 0) {
+ if (descriptor->extension_range_count() > 0) {
ScopedPyObjectPtr by_name(PyDict_New());
if (PyObject_SetAttr(cls, k_extensions_by_name, by_name) < 0) {
return -1;
@@ -162,8 +165,8 @@ static int AddDescriptors(PyObject* cls, PyObject* descriptor) {
}
// For each field set: cls.<field>_FIELD_NUMBER = <number>
- for (int i = 0; i < message_descriptor->field_count(); ++i) {
- if (!AddFieldNumberToClass(cls, message_descriptor->field(i))) {
+ for (int i = 0; i < descriptor->field_count(); ++i) {
+ if (!AddFieldNumberToClass(cls, descriptor->field(i))) {
return -1;
}
}
@@ -173,8 +176,8 @@ static int AddDescriptors(PyObject* cls, PyObject* descriptor) {
// The enum descriptor we get from
// <messagedescriptor>.enum_types_by_name[name]
// which was built previously.
- for (int i = 0; i < message_descriptor->enum_type_count(); ++i) {
- const EnumDescriptor* enum_descriptor = message_descriptor->enum_type(i);
+ for (int i = 0; i < descriptor->enum_type_count(); ++i) {
+ const EnumDescriptor* enum_descriptor = descriptor->enum_type(i);
ScopedPyObjectPtr enum_type(
PyEnumDescriptor_FromDescriptor(enum_descriptor));
if (enum_type == NULL) {
@@ -212,8 +215,8 @@ static int AddDescriptors(PyObject* cls, PyObject* descriptor) {
// Extension descriptors come from
// <message descriptor>.extensions_by_name[name]
// which was defined previously.
- for (int i = 0; i < message_descriptor->extension_count(); ++i) {
- const google::protobuf::FieldDescriptor* field = message_descriptor->extension(i);
+ for (int i = 0; i < descriptor->extension_count(); ++i) {
+ const google::protobuf::FieldDescriptor* field = descriptor->extension(i);
ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field));
if (extension_field == NULL) {
return -1;
@@ -258,14 +261,14 @@ static PyObject* New(PyTypeObject* type,
}
// Check dict['DESCRIPTOR']
- PyObject* descriptor = PyDict_GetItem(dict, kDESCRIPTOR);
- if (descriptor == NULL) {
+ PyObject* py_descriptor = PyDict_GetItem(dict, kDESCRIPTOR);
+ if (py_descriptor == NULL) {
PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR");
return NULL;
}
- if (!PyObject_TypeCheck(descriptor, &PyMessageDescriptor_Type)) {
+ if (!PyObject_TypeCheck(py_descriptor, &PyMessageDescriptor_Type)) {
PyErr_Format(PyExc_TypeError, "Expected a message Descriptor, got %s",
- descriptor->ob_type->tp_name);
+ py_descriptor->ob_type->tp_name);
return NULL;
}
@@ -291,14 +294,28 @@ static PyObject* New(PyTypeObject* type,
}
// Cache the descriptor, both as Python object and as C++ pointer.
- const Descriptor* message_descriptor =
- PyMessageDescriptor_AsDescriptor(descriptor);
- if (message_descriptor == NULL) {
+ const Descriptor* descriptor =
+ PyMessageDescriptor_AsDescriptor(py_descriptor);
+ if (descriptor == NULL) {
+ return NULL;
+ }
+ Py_INCREF(py_descriptor);
+ newtype->py_message_descriptor = py_descriptor;
+ newtype->message_descriptor = descriptor;
+ // TODO(amauryfa): Don't always use the canonical pool of the descriptor,
+ // use the MessageFactory optionally passed in the class dict.
+ newtype->py_descriptor_pool = GetDescriptorPool_FromPool(
+ descriptor->file()->pool());
+ if (newtype->py_descriptor_pool == NULL) {
+ return NULL;
+ }
+ Py_INCREF(newtype->py_descriptor_pool);
+
+ // Add the message to the DescriptorPool.
+ if (cdescriptor_pool::RegisterMessageClass(newtype->py_descriptor_pool,
+ descriptor, result) < 0) {
return NULL;
}
- Py_INCREF(descriptor);
- newtype->py_message_descriptor = descriptor;
- newtype->message_descriptor = message_descriptor;
// Continue with type initialization: add other descriptors, enum values...
if (AddDescriptors(result, descriptor) < 0) {
@@ -309,6 +326,7 @@ static PyObject* New(PyTypeObject* type,
static void Dealloc(PyMessageMeta *self) {
Py_DECREF(self->py_message_descriptor);
+ Py_DECREF(self->py_descriptor_pool);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@@ -381,12 +399,20 @@ PyTypeObject PyMessageMeta_Type = {
message_meta::New, // tp_new
};
-static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) {
+static PyMessageMeta* CheckMessageClass(PyTypeObject* cls) {
if (!PyObject_TypeCheck(cls, &PyMessageMeta_Type)) {
PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name);
return NULL;
}
- return reinterpret_cast<PyMessageMeta*>(cls)->message_descriptor;
+ return reinterpret_cast<PyMessageMeta*>(cls);
+}
+
+static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) {
+ PyMessageMeta* type = CheckMessageClass(cls);
+ if (type == NULL) {
+ return NULL;
+ }
+ return type->message_descriptor;
}
// Forward declarations
@@ -723,6 +749,17 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
namespace cmessage {
+PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message) {
+ // No need to check the type: the type of instances of CMessage is always
+ // an instance of PyMessageMeta. Let's prove it with a debug-only check.
+ GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type));
+ return reinterpret_cast<PyMessageMeta*>(Py_TYPE(message))->py_descriptor_pool;
+}
+
+MessageFactory* GetFactoryForMessage(CMessage* message) {
+ return GetDescriptorPoolForMessage(message)->message_factory;
+}
+
static int MaybeReleaseOverlappingOneofField(
CMessage* cmessage,
const FieldDescriptor* field) {
@@ -773,7 +810,7 @@ static Message* GetMutableMessage(
return NULL;
}
return reflection->MutableMessage(
- parent_message, parent_field, GetDescriptorPool()->message_factory);
+ parent_message, parent_field, GetFactoryForMessage(parent));
}
struct FixupMessageReference : public ChildVisitor {
@@ -814,10 +851,7 @@ int AssureWritable(CMessage* self) {
// If parent is NULL but we are trying to modify a read-only message, this
// is a reference to a constant default instance that needs to be replaced
// with a mutable top-level message.
- const Message* prototype =
- GetDescriptorPool()->message_factory->GetPrototype(
- self->message->GetDescriptor());
- self->message = prototype->New();
+ self->message = self->message->New();
self->owner.reset(self->message);
// Cascade the new owner to eventual children: even if this message is
// empty, some submessages or repeated containers might exist already.
@@ -1190,15 +1224,19 @@ CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) {
// The __new__ method of Message classes.
// Creates a new C++ message and takes ownership.
-static PyObject* New(PyTypeObject* type,
+static PyObject* New(PyTypeObject* cls,
PyObject* unused_args, PyObject* unused_kwargs) {
+ PyMessageMeta* type = CheckMessageClass(cls);
+ if (type == NULL) {
+ return NULL;
+ }
// Retrieve the message descriptor and the default instance (=prototype).
- const Descriptor* message_descriptor = GetMessageDescriptor(type);
+ const Descriptor* message_descriptor = type->message_descriptor;
if (message_descriptor == NULL) {
return NULL;
}
- const Message* default_message =
- GetDescriptorPool()->message_factory->GetPrototype(message_descriptor);
+ const Message* default_message = type->py_descriptor_pool->message_factory
+ ->GetPrototype(message_descriptor);
if (default_message == NULL) {
PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str());
return NULL;
@@ -1528,7 +1566,7 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) {
Message* ReleaseMessage(CMessage* self,
const Descriptor* descriptor,
const FieldDescriptor* field_descriptor) {
- MessageFactory* message_factory = GetDescriptorPool()->message_factory;
+ MessageFactory* message_factory = GetFactoryForMessage(self);
Message* released_message = self->message->GetReflection()->ReleaseMessage(
self->message, field_descriptor, message_factory);
// ReleaseMessage will return NULL which differs from
@@ -1883,8 +1921,8 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
AssureWritable(self);
io::CodedInputStream input(
reinterpret_cast<const uint8*>(data), data_length);
- input.SetExtensionRegistry(GetDescriptorPool()->pool,
- GetDescriptorPool()->message_factory);
+ PyDescriptorPool* pool = GetDescriptorPoolForMessage(self);
+ input.SetExtensionRegistry(pool->pool, pool->message_factory);
bool success = self->message->MergePartialFromCodedStream(&input);
if (success) {
return PyInt_FromLong(input.CurrentPosition());
@@ -1907,11 +1945,6 @@ static PyObject* ByteSize(CMessage* self, PyObject* args) {
static PyObject* RegisterExtension(PyObject* cls,
PyObject* extension_handle) {
- ScopedPyObjectPtr message_descriptor(PyObject_GetAttr(cls, kDESCRIPTOR));
- if (message_descriptor == NULL) {
- return NULL;
- }
-
const FieldDescriptor* descriptor =
GetExtensionDescriptor(extension_handle);
if (descriptor == NULL) {
@@ -1920,13 +1953,6 @@ static PyObject* RegisterExtension(PyObject* cls,
const Descriptor* cmessage_descriptor = GetMessageDescriptor(
reinterpret_cast<PyTypeObject*>(cls));
- if (cmessage_descriptor != descriptor->containing_type()) {
- if (PyObject_SetAttrString(extension_handle, "containing_type",
- message_descriptor) < 0) {
- return NULL;
- }
- }
-
ScopedPyObjectPtr extensions_by_name(
PyObject_GetAttr(cls, k_extensions_by_name));
if (extensions_by_name == NULL) {
@@ -2050,7 +2076,8 @@ static PyObject* ListFields(CMessage* self) {
// TODO(amauryfa): consider building the class on the fly!
if (fields[i]->message_type() != NULL &&
cdescriptor_pool::GetMessageClass(
- GetDescriptorPool(), fields[i]->message_type()) == NULL) {
+ GetDescriptorPoolForMessage(self),
+ fields[i]->message_type()) == NULL) {
PyErr_Clear();
continue;
}
@@ -2207,7 +2234,9 @@ PyObject* InternalGetScalar(const Message* message,
message->GetReflection()->GetUnknownFields(*message);
for (int i = 0; i < unknown_field_set.field_count(); ++i) {
if (unknown_field_set.field(i).number() ==
- field_descriptor->number()) {
+ field_descriptor->number() &&
+ unknown_field_set.field(i).type() ==
+ google::protobuf::UnknownField::TYPE_VARINT) {
result = PyInt_FromLong(unknown_field_set.field(i).varint());
break;
}
@@ -2233,11 +2262,12 @@ PyObject* InternalGetScalar(const Message* message,
PyObject* InternalGetSubMessage(
CMessage* self, const FieldDescriptor* field_descriptor) {
const Reflection* reflection = self->message->GetReflection();
+ PyDescriptorPool* pool = GetDescriptorPoolForMessage(self);
const Message& sub_message = reflection->GetMessage(
- *self->message, field_descriptor, GetDescriptorPool()->message_factory);
+ *self->message, field_descriptor, pool->message_factory);
PyObject *message_class = cdescriptor_pool::GetMessageClass(
- GetDescriptorPool(), field_descriptor->message_type());
+ pool, field_descriptor->message_type());
if (message_class == NULL) {
return NULL;
}
@@ -2560,7 +2590,7 @@ PyObject* GetAttr(CMessage* self, PyObject* name) {
const FieldDescriptor* value_type = entry_type->FindFieldByName("value");
if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
PyObject* value_class = cdescriptor_pool::GetMessageClass(
- GetDescriptorPool(), value_type->message_type());
+ GetDescriptorPoolForMessage(self), value_type->message_type());
if (value_class == NULL) {
return NULL;
}
@@ -2583,7 +2613,7 @@ PyObject* GetAttr(CMessage* self, PyObject* name) {
PyObject* py_container = NULL;
if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
PyObject *message_class = cdescriptor_pool::GetMessageClass(
- GetDescriptorPool(), field_descriptor->message_type());
+ GetDescriptorPoolForMessage(self), field_descriptor->message_type());
if (message_class == NULL) {
return NULL;
}
@@ -2908,9 +2938,10 @@ bool InitProto2MessageModule(PyObject *m) {
// Expose the DescriptorPool used to hold all descriptors added from generated
// pb2.py files.
- Py_INCREF(GetDescriptorPool()); // PyModule_AddObject steals a reference.
- PyModule_AddObject(
- m, "default_pool", reinterpret_cast<PyObject*>(GetDescriptorPool()));
+ // PyModule_AddObject steals a reference.
+ Py_INCREF(GetDefaultDescriptorPool());
+ PyModule_AddObject(m, "default_pool",
+ reinterpret_cast<PyObject*>(GetDefaultDescriptorPool()));
// This implementation provides full Descriptor types, we advertise it so that
// descriptor.py can use them in replacement of the Python classes.