aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/pyext/message_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/pyext/message_factory.cc')
-rw-r--r--python/google/protobuf/pyext/message_factory.cc66
1 files changed, 66 insertions, 0 deletions
diff --git a/python/google/protobuf/pyext/message_factory.cc b/python/google/protobuf/pyext/message_factory.cc
index 2ad89022..e0b45bf2 100644
--- a/python/google/protobuf/pyext/message_factory.cc
+++ b/python/google/protobuf/pyext/message_factory.cc
@@ -130,6 +130,72 @@ int RegisterMessageClass(PyMessageFactory* self,
return 0;
}
+CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
+ const Descriptor* descriptor) {
+ // This is the same implementation as MessageFactory.GetPrototype().
+ ScopedPyObjectPtr py_descriptor(
+ PyMessageDescriptor_FromDescriptor(descriptor));
+ if (py_descriptor == NULL) {
+ return NULL;
+ }
+ // Do not create a MessageClass that already exists.
+ hash_map<const Descriptor*, CMessageClass*>::iterator it =
+ self->classes_by_descriptor->find(descriptor);
+ if (it != self->classes_by_descriptor->end()) {
+ Py_INCREF(it->second);
+ return it->second;
+ }
+ // Create a new message class.
+ ScopedPyObjectPtr args(Py_BuildValue(
+ "s(){sOsOsO}", descriptor->name().c_str(),
+ "DESCRIPTOR", py_descriptor.get(),
+ "__module__", Py_None,
+ "message_factory", self));
+ if (args == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr message_class(PyObject_CallObject(
+ reinterpret_cast<PyObject*>(&CMessageClass_Type), args.get()));
+ if (message_class == NULL) {
+ return NULL;
+ }
+ // Create messages class for the messages used by the fields, and registers
+ // all extensions for these messages during the recursion.
+ for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
+ const Descriptor* sub_descriptor =
+ descriptor->field(field_idx)->message_type();
+ // It is NULL if the field type is not a message.
+ if (sub_descriptor != NULL) {
+ CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor);
+ if (result == NULL) {
+ return NULL;
+ }
+ Py_DECREF(result);
+ }
+ }
+
+ // Register extensions defined in this message.
+ for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) {
+ const FieldDescriptor* extension = descriptor->extension(ext_idx);
+ ScopedPyObjectPtr py_extended_class(
+ GetOrCreateMessageClass(self, extension->containing_type())
+ ->AsPyObject());
+ if (py_extended_class == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension));
+ if (py_extension == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr result(cmessage::RegisterExtension(
+ py_extended_class.get(), py_extension.get()));
+ if (result == NULL) {
+ return NULL;
+ }
+ }
+ return reinterpret_cast<CMessageClass*>(message_class.release());
+}
+
// Retrieve the message class added to our database.
CMessageClass* GetMessageClass(PyMessageFactory* self,
const Descriptor* message_descriptor) {