diff options
Diffstat (limited to 'python/google/protobuf/pyext/message_factory.cc')
-rw-r--r-- | python/google/protobuf/pyext/message_factory.cc | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/python/google/protobuf/pyext/message_factory.cc b/python/google/protobuf/pyext/message_factory.cc index 2ad89022..e0b45bf2 100644 --- a/python/google/protobuf/pyext/message_factory.cc +++ b/python/google/protobuf/pyext/message_factory.cc @@ -130,6 +130,72 @@ int RegisterMessageClass(PyMessageFactory* self, return 0; } +CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self, + const Descriptor* descriptor) { + // This is the same implementation as MessageFactory.GetPrototype(). + ScopedPyObjectPtr py_descriptor( + PyMessageDescriptor_FromDescriptor(descriptor)); + if (py_descriptor == NULL) { + return NULL; + } + // Do not create a MessageClass that already exists. + hash_map<const Descriptor*, CMessageClass*>::iterator it = + self->classes_by_descriptor->find(descriptor); + if (it != self->classes_by_descriptor->end()) { + Py_INCREF(it->second); + return it->second; + } + // Create a new message class. + ScopedPyObjectPtr args(Py_BuildValue( + "s(){sOsOsO}", descriptor->name().c_str(), + "DESCRIPTOR", py_descriptor.get(), + "__module__", Py_None, + "message_factory", self)); + if (args == NULL) { + return NULL; + } + ScopedPyObjectPtr message_class(PyObject_CallObject( + reinterpret_cast<PyObject*>(&CMessageClass_Type), args.get())); + if (message_class == NULL) { + return NULL; + } + // Create messages class for the messages used by the fields, and registers + // all extensions for these messages during the recursion. + for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) { + const Descriptor* sub_descriptor = + descriptor->field(field_idx)->message_type(); + // It is NULL if the field type is not a message. + if (sub_descriptor != NULL) { + CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor); + if (result == NULL) { + return NULL; + } + Py_DECREF(result); + } + } + + // Register extensions defined in this message. + for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) { + const FieldDescriptor* extension = descriptor->extension(ext_idx); + ScopedPyObjectPtr py_extended_class( + GetOrCreateMessageClass(self, extension->containing_type()) + ->AsPyObject()); + if (py_extended_class == NULL) { + return NULL; + } + ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension)); + if (py_extension == NULL) { + return NULL; + } + ScopedPyObjectPtr result(cmessage::RegisterExtension( + py_extended_class.get(), py_extension.get())); + if (result == NULL) { + return NULL; + } + } + return reinterpret_cast<CMessageClass*>(message_class.release()); +} + // Retrieve the message class added to our database. CMessageClass* GetMessageClass(PyMessageFactory* self, const Descriptor* message_descriptor) { |