diff options
Diffstat (limited to 'python/google/protobuf/pyext/message.cc')
-rw-r--r-- | python/google/protobuf/pyext/message.cc | 724 |
1 files changed, 458 insertions, 266 deletions
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index a4843e8d..aa3ab97a 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -49,9 +49,10 @@ #endif #include <google/protobuf/descriptor.pb.h> #include <google/protobuf/stubs/common.h> +#include <google/protobuf/stubs/logging.h> #include <google/protobuf/io/coded_stream.h> +#include <google/protobuf/util/message_differencer.h> #include <google/protobuf/descriptor.h> -#include <google/protobuf/dynamic_message.h> #include <google/protobuf/message.h> #include <google/protobuf/text_format.h> #include <google/protobuf/pyext/descriptor.h> @@ -88,12 +89,308 @@ namespace google { namespace protobuf { namespace python { +static PyObject* kDESCRIPTOR; +static PyObject* k_extensions_by_name; +static PyObject* k_extensions_by_number; +PyObject* EnumTypeWrapper_class; +static PyObject* PythonMessage_class; +static PyObject* kEmptyWeakref; + +// Defines the Metaclass of all Message classes. +// It allows us to cache some C++ pointers in the class object itself, they are +// faster to extract than from the type's dictionary. + +struct PyMessageMeta { + // This is how CPython subclasses C structures: the base structure must be + // the first member of the object. + PyHeapTypeObject super; + + // C++ descriptor of this message. + const Descriptor* message_descriptor; + // Owned reference, used to keep the pointer above alive. + PyObject* py_message_descriptor; +}; + +namespace message_meta { + +static int InsertEmptyWeakref(PyTypeObject* base); + +// Add the number of a field descriptor to the containing message class. +// Equivalent to: +// _cls.<field>_FIELD_NUMBER = <number> +static bool AddFieldNumberToClass( + PyObject* cls, const FieldDescriptor* field_descriptor) { + string constant_name = field_descriptor->name() + "_FIELD_NUMBER"; + UpperString(&constant_name); + ScopedPyObjectPtr attr_name(PyString_FromStringAndSize( + constant_name.c_str(), constant_name.size())); + if (attr_name == NULL) { + return false; + } + ScopedPyObjectPtr number(PyInt_FromLong(field_descriptor->number())); + if (number == NULL) { + return false; + } + if (PyObject_SetAttr(cls, attr_name, number) == -1) { + return false; + } + return true; +} + + +// Finalize the creation of the Message class. +// Called from its metaclass: GeneratedProtocolMessageType.__init__(). +static int AddDescriptors(PyObject* cls, PyObject* descriptor) { + const Descriptor* message_descriptor = + cdescriptor_pool::RegisterMessageClass( + GetDescriptorPool(), cls, descriptor); + if (message_descriptor == NULL) { + return -1; + } + + // If there are extension_ranges, the message is "extendable", and extension + // classes will register themselves in this class. + if (message_descriptor->extension_range_count() > 0) { + ScopedPyObjectPtr by_name(PyDict_New()); + if (PyObject_SetAttr(cls, k_extensions_by_name, by_name) < 0) { + return -1; + } + ScopedPyObjectPtr by_number(PyDict_New()); + if (PyObject_SetAttr(cls, k_extensions_by_number, by_number) < 0) { + return -1; + } + } + + // For each field set: cls.<field>_FIELD_NUMBER = <number> + for (int i = 0; i < message_descriptor->field_count(); ++i) { + if (!AddFieldNumberToClass(cls, message_descriptor->field(i))) { + return -1; + } + } + + // For each enum set cls.<enum name> = EnumTypeWrapper(<enum descriptor>). + // + // The enum descriptor we get from + // <messagedescriptor>.enum_types_by_name[name] + // which was built previously. + for (int i = 0; i < message_descriptor->enum_type_count(); ++i) { + const EnumDescriptor* enum_descriptor = message_descriptor->enum_type(i); + ScopedPyObjectPtr enum_type( + PyEnumDescriptor_FromDescriptor(enum_descriptor)); + if (enum_type == NULL) { + return -1; + } + // Add wrapped enum type to message class. + ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs( + EnumTypeWrapper_class, enum_type.get(), NULL)); + if (wrapped == NULL) { + return -1; + } + if (PyObject_SetAttrString( + cls, enum_descriptor->name().c_str(), wrapped) == -1) { + return -1; + } + + // For each enum value add cls.<name> = <number> + for (int j = 0; j < enum_descriptor->value_count(); ++j) { + const EnumValueDescriptor* enum_value_descriptor = + enum_descriptor->value(j); + ScopedPyObjectPtr value_number(PyInt_FromLong( + enum_value_descriptor->number())); + if (value_number == NULL) { + return -1; + } + if (PyObject_SetAttrString( + cls, enum_value_descriptor->name().c_str(), value_number) == -1) { + return -1; + } + } + } + + // For each extension set cls.<extension name> = <extension descriptor>. + // + // Extension descriptors come from + // <message descriptor>.extensions_by_name[name] + // which was defined previously. + for (int i = 0; i < message_descriptor->extension_count(); ++i) { + const google::protobuf::FieldDescriptor* field = message_descriptor->extension(i); + ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field)); + if (extension_field == NULL) { + return -1; + } + + // Add the extension field to the message class. + if (PyObject_SetAttrString( + cls, field->name().c_str(), extension_field) == -1) { + return -1; + } + + // For each extension set cls.<extension name>_FIELD_NUMBER = <number>. + if (!AddFieldNumberToClass(cls, field)) { + return -1; + } + } + + return 0; +} + +static PyObject* New(PyTypeObject* type, + PyObject* args, PyObject* kwargs) { + static char *kwlist[] = {"name", "bases", "dict", 0}; + PyObject *bases, *dict; + const char* name; + + // Check arguments: (name, bases, dict) + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "sO!O!:type", kwlist, + &name, + &PyTuple_Type, &bases, + &PyDict_Type, &dict)) { + return NULL; + } + + // Check bases: only (), or (message.Message,) are allowed + if (!(PyTuple_GET_SIZE(bases) == 0 || + (PyTuple_GET_SIZE(bases) == 1 && + PyTuple_GET_ITEM(bases, 0) == PythonMessage_class))) { + PyErr_SetString(PyExc_TypeError, + "A Message class can only inherit from Message"); + return NULL; + } + + // Check dict['DESCRIPTOR'] + PyObject* descriptor = PyDict_GetItem(dict, kDESCRIPTOR); + if (descriptor == NULL) { + PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR"); + return NULL; + } + if (!PyObject_TypeCheck(descriptor, &PyMessageDescriptor_Type)) { + PyErr_Format(PyExc_TypeError, "Expected a message Descriptor, got %s", + descriptor->ob_type->tp_name); + return NULL; + } + + // Build the arguments to the base metaclass. + // We change the __bases__ classes. + ScopedPyObjectPtr new_args(Py_BuildValue( + "s(OO)O", name, &CMessage_Type, PythonMessage_class, dict)); + if (new_args == NULL) { + return NULL; + } + // Call the base metaclass. + ScopedPyObjectPtr result(PyType_Type.tp_new(type, new_args, NULL)); + if (result == NULL) { + return NULL; + } + PyMessageMeta* newtype = reinterpret_cast<PyMessageMeta*>(result.get()); + + // Insert the empty weakref into the base classes. + if (InsertEmptyWeakref( + reinterpret_cast<PyTypeObject*>(PythonMessage_class)) < 0 || + InsertEmptyWeakref(&CMessage_Type) < 0) { + return NULL; + } + + // Cache the descriptor, both as Python object and as C++ pointer. + const Descriptor* message_descriptor = + PyMessageDescriptor_AsDescriptor(descriptor); + if (message_descriptor == NULL) { + return NULL; + } + Py_INCREF(descriptor); + newtype->py_message_descriptor = descriptor; + newtype->message_descriptor = message_descriptor; + + // Continue with type initialization: add other descriptors, enum values... + if (AddDescriptors(result, descriptor) < 0) { + return NULL; + } + return result.release(); +} + +static void Dealloc(PyMessageMeta *self) { + Py_DECREF(self->py_message_descriptor); + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +} + +static PyObject* GetDescriptor(PyMessageMeta *self, void *closure) { + Py_INCREF(self->py_message_descriptor); + return self->py_message_descriptor; +} + + +// This function inserts and empty weakref at the end of the list of +// subclasses for the main protocol buffer Message class. +// +// This eliminates a O(n^2) behaviour in the internal add_subclass +// routine. +static int InsertEmptyWeakref(PyTypeObject *base_type) { +#if PY_MAJOR_VERSION >= 3 + // Python 3.4 has already included the fix for the issue that this + // hack addresses. For further background and the fix please see + // https://bugs.python.org/issue17936. + return 0; +#else + PyObject *subclasses = base_type->tp_subclasses; + if (subclasses && PyList_CheckExact(subclasses)) { + return PyList_Append(subclasses, kEmptyWeakref); + } + return 0; +#endif // PY_MAJOR_VERSION >= 3 +} + +} // namespace message_meta + +PyTypeObject PyMessageMeta_Type { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".MessageMeta", // tp_name + sizeof(PyMessageMeta), // tp_basicsize + 0, // tp_itemsize + (destructor)message_meta::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags + "The metaclass of ProtocolMessages", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + 0, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + message_meta::New, // tp_new +}; + +static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) { + if (!PyObject_TypeCheck(cls, &PyMessageMeta_Type)) { + PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name); + return NULL; + } + return reinterpret_cast<PyMessageMeta*>(cls)->message_descriptor; +} + // Forward declarations namespace cmessage { -static const FieldDescriptor* GetFieldDescriptor( - CMessage* self, PyObject* name); -static const Descriptor* GetMessageDescriptor(PyTypeObject* cls); -static string GetMessageName(CMessage* self); int InternalReleaseFieldByDescriptor( CMessage* self, const FieldDescriptor* field_descriptor, @@ -180,7 +477,7 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) { if (self->composite_fields) { // Never use self->message in this function, it may be already freed. const Descriptor* message_descriptor = - cmessage::GetMessageDescriptor(Py_TYPE(self)); + GetMessageDescriptor(Py_TYPE(self)); while (PyDict_Next(self->composite_fields, &pos, &key, &field)) { Py_ssize_t key_str_size; char *key_str_data; @@ -213,8 +510,6 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) { // --------------------------------------------------------------------- -static DynamicMessageFactory* message_factory; - // Constants used for integer type range checking. PyObject* kPythonZero; PyObject* kint32min_py; @@ -224,17 +519,13 @@ PyObject* kint64min_py; PyObject* kint64max_py; PyObject* kuint64max_py; -PyObject* EnumTypeWrapper_class; PyObject* EncodeError_class; PyObject* DecodeError_class; PyObject* PickleError_class; // Constant PyString values used for GetAttr/GetItem. -static PyObject* kDESCRIPTOR; static PyObject* k_cdescriptor; static PyObject* kfull_name; -static PyObject* k_extensions_by_name; -static PyObject* k_extensions_by_number; /* Is 64bit */ void FormatTypeError(PyObject* arg, char* expected_types) { @@ -432,10 +723,6 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, namespace cmessage { -DynamicMessageFactory* GetMessageFactory() { - return message_factory; -} - static int MaybeReleaseOverlappingOneofField( CMessage* cmessage, const FieldDescriptor* field) { @@ -486,7 +773,7 @@ static Message* GetMutableMessage( return NULL; } return reflection->MutableMessage( - parent_message, parent_field, message_factory); + parent_message, parent_field, GetDescriptorPool()->message_factory); } struct FixupMessageReference : public ChildVisitor { @@ -527,8 +814,9 @@ int AssureWritable(CMessage* self) { // If parent is NULL but we are trying to modify a read-only message, this // is a reference to a constant default instance that needs to be replaced // with a mutable top-level message. - const Message* prototype = message_factory->GetPrototype( - self->message->GetDescriptor()); + const Message* prototype = + GetDescriptorPool()->message_factory->GetPrototype( + self->message->GetDescriptor()); self->message = prototype->New(); self->owner.reset(self->message); // Cascade the new owner to eventual children: even if this message is @@ -567,23 +855,6 @@ int AssureWritable(CMessage* self) { // --- Globals: -// Retrieve the C++ Descriptor of a message class. -// On error, returns NULL with an exception set. -static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) { - ScopedPyObjectPtr descriptor(PyObject_GetAttr( - reinterpret_cast<PyObject*>(cls), kDESCRIPTOR)); - if (descriptor == NULL) { - PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR"); - return NULL; - } - if (!PyObject_TypeCheck(descriptor, &PyMessageDescriptor_Type)) { - PyErr_Format(PyExc_TypeError, "Expected a message Descriptor, got %s", - descriptor->ob_type->tp_name); - return NULL; - } - return PyMessageDescriptor_AsDescriptor(descriptor); -} - // Retrieve a C++ FieldDescriptor for a message attribute. // The C++ message must be valid. // TODO(amauryfa): This function should stay internal, because exception @@ -846,9 +1117,9 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { return -1; } } else { - if (repeated_scalar_container::Extend( + if (ScopedPyObjectPtr(repeated_scalar_container::Extend( reinterpret_cast<RepeatedScalarContainer*>(container.get()), - value) == + value)) == NULL) { return -1; } @@ -927,7 +1198,7 @@ static PyObject* New(PyTypeObject* type, return NULL; } const Message* default_message = - message_factory->GetPrototype(message_descriptor); + GetDescriptorPool()->message_factory->GetPrototype(message_descriptor); if (default_message == NULL) { PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str()); return NULL; @@ -1257,6 +1528,7 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) { Message* ReleaseMessage(CMessage* self, const Descriptor* descriptor, const FieldDescriptor* field_descriptor) { + MessageFactory* message_factory = GetDescriptorPool()->message_factory; Message* released_message = self->message->GetReflection()->ReleaseMessage( self->message, field_descriptor, message_factory); // ReleaseMessage will return NULL which differs from @@ -1492,34 +1764,35 @@ static PyObject* SerializePartialToString(CMessage* self) { // appropriate. class PythonFieldValuePrinter : public TextFormat::FieldValuePrinter { public: - PythonFieldValuePrinter() : float_holder_(PyFloat_FromDouble(0)) {} - // Python has some differences from C++ when printing floating point numbers. // // 1) Trailing .0 is always printed. - // 2) Outputted is rounded to 12 digits. + // 2) (Python2) Output is rounded to 12 digits. + // 3) (Python3) The full precision of the double is preserved (and Python uses + // David M. Gay's dtoa(), when the C++ code uses SimpleDtoa. There are some + // differences, but they rarely happen) // // We override floating point printing with the C-API function for printing // Python floats to ensure consistency. string PrintFloat(float value) const { return PrintDouble(value); } string PrintDouble(double value) const { - reinterpret_cast<PyFloatObject*>(float_holder_.get())->ob_fval = value; - ScopedPyObjectPtr s(PyObject_Str(float_holder_.get())); - if (s == NULL) return string(); + // Same as float.__str__() + char* buf = PyOS_double_to_string( + value, #if PY_MAJOR_VERSION < 3 - char *cstr = PyBytes_AS_STRING(static_cast<PyObject*>(s)); + 'g', PyFloat_STR_PRECISION, // Output is rounded to 12 digits. #else - char *cstr = PyUnicode_AsUTF8(s); + 'r', 0, #endif - return string(cstr); + Py_DTSF_ADD_DOT_0, // Trailing .0 is always printed. + NULL); + if (!buf) { + return string(); + } + string result(buf); + PyMem_Free(buf); + return result; } - - private: - // Holder for a python float object which we use to allow us to use - // the Python API for printing doubles. We initialize once and then - // directly modify it for every float printed to save on allocations - // and refcounting. - ScopedPyObjectPtr float_holder_; }; static PyObject* ToStr(CMessage* self) { @@ -1590,7 +1863,7 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) { // CopyFrom on the message will not clean up self->composite_fields, // which can leave us in an inconsistent state, so clear it out here. - Clear(self); + (void)ScopedPyObjectPtr(Clear(self)); self->message->CopyFrom(*other_message->message); @@ -1607,7 +1880,8 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { AssureWritable(self); io::CodedInputStream input( reinterpret_cast<const uint8*>(data), data_length); - input.SetExtensionRegistry(GetDescriptorPool()->pool, message_factory); + input.SetExtensionRegistry(GetDescriptorPool()->pool, + GetDescriptorPool()->message_factory); bool success = self->message->MergePartialFromCodedStream(&input); if (success) { return PyInt_FromLong(input.CurrentPosition()); @@ -1618,7 +1892,7 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { } static PyObject* ParseFromString(CMessage* self, PyObject* arg) { - if (Clear(self) == NULL) { + if (ScopedPyObjectPtr(Clear(self)) == NULL) { return NULL; } return MergeFromString(self, arg); @@ -1790,6 +2064,7 @@ static PyObject* ListFields(CMessage* self) { // Steals reference to 'extension' PyTuple_SET_ITEM(t.get(), 1, extension); } else { + // Normal field const string& field_name = fields[i]->name(); ScopedPyObjectPtr py_field_name(PyString_FromStringAndSize( field_name.c_str(), field_name.length())); @@ -1841,28 +2116,34 @@ PyObject* FindInitializationErrors(CMessage* self) { } static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) { - if (!PyObject_TypeCheck(other, &CMessage_Type)) { - if (opid == Py_EQ) { - Py_RETURN_FALSE; - } else if (opid == Py_NE) { - Py_RETURN_TRUE; - } - } - if (opid == Py_EQ || opid == Py_NE) { - ScopedPyObjectPtr self_fields(ListFields(self)); - if (!self_fields) { - return NULL; - } - ScopedPyObjectPtr other_fields(ListFields( - reinterpret_cast<CMessage*>(other))); - if (!other_fields) { - return NULL; - } - return PyObject_RichCompare(self_fields, other_fields, opid); - } else { + // Only equality comparisons are implemented. + if (opid != Py_EQ && opid != Py_NE) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } + bool equals = true; + // If other is not a message, it cannot be equal. + if (!PyObject_TypeCheck(other, &CMessage_Type)) { + equals = false; + } + const google::protobuf::Message* other_message = + reinterpret_cast<CMessage*>(other)->message; + // If messages don't have the same descriptors, they are not equal. + if (equals && + self->message->GetDescriptor() != other_message->GetDescriptor()) { + equals = false; + } + // Check the message contents. + if (equals && !google::protobuf::util::MessageDifferencer::Equals( + *self->message, + *reinterpret_cast<CMessage*>(other)->message)) { + equals = false; + } + if (equals ^ (opid == Py_EQ)) { + Py_RETURN_FALSE; + } else { + Py_RETURN_TRUE; + } } PyObject* InternalGetScalar(const Message* message, @@ -1950,7 +2231,7 @@ PyObject* InternalGetSubMessage( CMessage* self, const FieldDescriptor* field_descriptor) { const Reflection* reflection = self->message->GetReflection(); const Message& sub_message = reflection->GetMessage( - *self->message, field_descriptor, message_factory); + *self->message, field_descriptor, GetDescriptorPool()->message_factory); PyObject *message_class = cdescriptor_pool::GetMessageClass( GetDescriptorPool(), field_descriptor->message_type()); @@ -2085,125 +2366,6 @@ PyObject* FromString(PyTypeObject* cls, PyObject* serialized) { return py_cmsg; } -// Add the number of a field descriptor to the containing message class. -// Equivalent to: -// _cls.<field>_FIELD_NUMBER = <number> -static bool AddFieldNumberToClass( - PyObject* cls, const FieldDescriptor* field_descriptor) { - string constant_name = field_descriptor->name() + "_FIELD_NUMBER"; - UpperString(&constant_name); - ScopedPyObjectPtr attr_name(PyString_FromStringAndSize( - constant_name.c_str(), constant_name.size())); - if (attr_name == NULL) { - return false; - } - ScopedPyObjectPtr number(PyInt_FromLong(field_descriptor->number())); - if (number == NULL) { - return false; - } - if (PyObject_SetAttr(cls, attr_name, number) == -1) { - return false; - } - return true; -} - - -// Finalize the creation of the Message class. -// Called from its metaclass: GeneratedProtocolMessageType.__init__(). -static PyObject* AddDescriptors(PyObject* cls, PyObject* descriptor) { - const Descriptor* message_descriptor = - cdescriptor_pool::RegisterMessageClass( - GetDescriptorPool(), cls, descriptor); - if (message_descriptor == NULL) { - return NULL; - } - - // If there are extension_ranges, the message is "extendable", and extension - // classes will register themselves in this class. - if (message_descriptor->extension_range_count() > 0) { - ScopedPyObjectPtr by_name(PyDict_New()); - if (PyObject_SetAttr(cls, k_extensions_by_name, by_name) < 0) { - return NULL; - } - ScopedPyObjectPtr by_number(PyDict_New()); - if (PyObject_SetAttr(cls, k_extensions_by_number, by_number) < 0) { - return NULL; - } - } - - // For each field set: cls.<field>_FIELD_NUMBER = <number> - for (int i = 0; i < message_descriptor->field_count(); ++i) { - if (!AddFieldNumberToClass(cls, message_descriptor->field(i))) { - return NULL; - } - } - - // For each enum set cls.<enum name> = EnumTypeWrapper(<enum descriptor>). - // - // The enum descriptor we get from - // <messagedescriptor>.enum_types_by_name[name] - // which was built previously. - for (int i = 0; i < message_descriptor->enum_type_count(); ++i) { - const EnumDescriptor* enum_descriptor = message_descriptor->enum_type(i); - ScopedPyObjectPtr enum_type( - PyEnumDescriptor_FromDescriptor(enum_descriptor)); - if (enum_type == NULL) { - return NULL; - } - // Add wrapped enum type to message class. - ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs( - EnumTypeWrapper_class, enum_type.get(), NULL)); - if (wrapped == NULL) { - return NULL; - } - if (PyObject_SetAttrString( - cls, enum_descriptor->name().c_str(), wrapped) == -1) { - return NULL; - } - - // For each enum value add cls.<name> = <number> - for (int j = 0; j < enum_descriptor->value_count(); ++j) { - const EnumValueDescriptor* enum_value_descriptor = - enum_descriptor->value(j); - ScopedPyObjectPtr value_number(PyInt_FromLong( - enum_value_descriptor->number())); - if (value_number == NULL) { - return NULL; - } - if (PyObject_SetAttrString( - cls, enum_value_descriptor->name().c_str(), value_number) == -1) { - return NULL; - } - } - } - - // For each extension set cls.<extension name> = <extension descriptor>. - // - // Extension descriptors come from - // <message descriptor>.extensions_by_name[name] - // which was defined previously. - for (int i = 0; i < message_descriptor->extension_count(); ++i) { - const google::protobuf::FieldDescriptor* field = message_descriptor->extension(i); - ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field)); - if (extension_field == NULL) { - return NULL; - } - - // Add the extension field to the message class. - if (PyObject_SetAttrString( - cls, field->name().c_str(), extension_field) == -1) { - return NULL; - } - - // For each extension set cls.<extension name>_FIELD_NUMBER = <number>. - if (!AddFieldNumberToClass(cls, field)) { - return NULL; - } - } - - Py_RETURN_NONE; -} - PyObject* DeepCopy(CMessage* self, PyObject* arg) { PyObject* clone = PyObject_CallObject( reinterpret_cast<PyObject*>(Py_TYPE(self)), NULL); @@ -2214,8 +2376,9 @@ PyObject* DeepCopy(CMessage* self, PyObject* arg) { Py_DECREF(clone); return NULL; } - if (MergeFrom(reinterpret_cast<CMessage*>(clone), - reinterpret_cast<PyObject*>(self)) == NULL) { + if (ScopedPyObjectPtr(MergeFrom( + reinterpret_cast<CMessage*>(clone), + reinterpret_cast<PyObject*>(self))) == NULL) { Py_DECREF(clone); return NULL; } @@ -2281,7 +2444,7 @@ PyObject* SetState(CMessage* self, PyObject* state) { if (serialized == NULL) { return NULL; } - if (ParseFromString(self, serialized) == NULL) { + if (ScopedPyObjectPtr(ParseFromString(self, serialized)) == NULL) { return NULL; } Py_RETURN_NONE; @@ -2314,8 +2477,6 @@ static PyMethodDef Methods[] = { "Inputs picklable representation of the message." }, { "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS, "Outputs a unicode representation of the message." }, - { "AddDescriptors", (PyCFunction)AddDescriptors, METH_O | METH_CLASS, - "Adds field descriptors to the class" }, { "ByteSize", (PyCFunction)ByteSize, METH_NOARGS, "Returns the size of the message in bytes." }, { "Clear", (PyCFunction)Clear, METH_NOARGS, @@ -2441,6 +2602,9 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { PyObject* sub_message = InternalGetSubMessage(self, field_descriptor); + if (sub_message == NULL) { + return NULL; + } if (!SetCompositeField(self, name, sub_message)) { Py_DECREF(sub_message); return NULL; @@ -2484,7 +2648,7 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) { } // namespace cmessage PyTypeObject CMessage_Type = { - PyVarObject_HEAD_INIT(&PyType_Type, 0) + PyVarObject_HEAD_INIT(&PyMessageMeta_Type, 0) FULL_MODULE_NAME ".CMessage", // tp_name sizeof(CMessage), // tp_basicsize 0, // tp_itemsize @@ -2497,7 +2661,7 @@ PyTypeObject CMessage_Type = { 0, // tp_as_number 0, // tp_as_sequence 0, // tp_as_mapping - 0, // tp_hash + PyObject_HashNotImplemented, // tp_hash 0, // tp_call (reprfunc)cmessage::ToStr, // tp_str (getattrofunc)cmessage::GetAttr, // tp_getattro @@ -2580,8 +2744,9 @@ void InitGlobals() { k_extensions_by_name = PyString_FromString("_extensions_by_name"); k_extensions_by_number = PyString_FromString("_extensions_by_number"); - message_factory = new DynamicMessageFactory(); - message_factory->SetDelegateToGeneratedFactory(true); + PyObject *dummy_obj = PySet_New(NULL); + kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL); + Py_DECREF(dummy_obj); } bool InitProto2MessageModule(PyObject *m) { @@ -2598,7 +2763,13 @@ bool InitProto2MessageModule(PyObject *m) { // Initialize constants defined in this file. InitGlobals(); - CMessage_Type.tp_hash = PyObject_HashNotImplemented; + PyMessageMeta_Type.tp_base = &PyType_Type; + if (PyType_Ready(&PyMessageMeta_Type) < 0) { + return false; + } + PyModule_AddObject(m, "MessageMeta", + reinterpret_cast<PyObject*>(&PyMessageMeta_Type)); + if (PyType_Ready(&CMessage_Type) < 0) { return false; } @@ -2628,86 +2799,106 @@ bool InitProto2MessageModule(PyObject *m) { PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(&CMessage_Type)); - RepeatedScalarContainer_Type.tp_hash = - PyObject_HashNotImplemented; - if (PyType_Ready(&RepeatedScalarContainer_Type) < 0) { - return false; - } + // Initialize Repeated container types. + { + if (PyType_Ready(&RepeatedScalarContainer_Type) < 0) { + return false; + } - PyModule_AddObject(m, "RepeatedScalarContainer", - reinterpret_cast<PyObject*>( - &RepeatedScalarContainer_Type)); + PyModule_AddObject(m, "RepeatedScalarContainer", + reinterpret_cast<PyObject*>( + &RepeatedScalarContainer_Type)); - RepeatedCompositeContainer_Type.tp_hash = PyObject_HashNotImplemented; - if (PyType_Ready(&RepeatedCompositeContainer_Type) < 0) { - return false; - } + if (PyType_Ready(&RepeatedCompositeContainer_Type) < 0) { + return false; + } - PyModule_AddObject( - m, "RepeatedCompositeContainer", - reinterpret_cast<PyObject*>( - &RepeatedCompositeContainer_Type)); - - // ScalarMapContainer_Type derives from our MutableMapping type. - PyObject* containers = - PyImport_ImportModule("google.protobuf.internal.containers"); - if (containers == NULL) { - return false; + PyModule_AddObject( + m, "RepeatedCompositeContainer", + reinterpret_cast<PyObject*>( + &RepeatedCompositeContainer_Type)); + + // Register them as collections.Sequence + ScopedPyObjectPtr collections(PyImport_ImportModule("collections")); + if (collections == NULL) { + return false; + } + ScopedPyObjectPtr mutable_sequence(PyObject_GetAttrString( + collections, "MutableSequence")); + if (mutable_sequence == NULL) { + return false; + } + if (ScopedPyObjectPtr(PyObject_CallMethod(mutable_sequence, "register", "O", + &RepeatedScalarContainer_Type)) + == NULL) { + return false; + } + if (ScopedPyObjectPtr(PyObject_CallMethod(mutable_sequence, "register", "O", + &RepeatedCompositeContainer_Type)) + == NULL) { + return false; + } } - PyObject* mutable_mapping = - PyObject_GetAttrString(containers, "MutableMapping"); - Py_DECREF(containers); + // Initialize Map container types. + { + // ScalarMapContainer_Type derives from our MutableMapping type. + ScopedPyObjectPtr containers(PyImport_ImportModule( + "google.protobuf.internal.containers")); + if (containers == NULL) { + return false; + } - if (mutable_mapping == NULL) { - return false; - } + ScopedPyObjectPtr mutable_mapping( + PyObject_GetAttrString(containers, "MutableMapping")); + if (mutable_mapping == NULL) { + return false; + } - if (!PyObject_TypeCheck(mutable_mapping, &PyType_Type)) { - Py_DECREF(mutable_mapping); - return false; - } + if (!PyObject_TypeCheck(mutable_mapping, &PyType_Type)) { + return false; + } - ScalarMapContainer_Type.tp_base = - reinterpret_cast<PyTypeObject*>(mutable_mapping); + Py_INCREF(mutable_mapping); + ScalarMapContainer_Type.tp_base = + reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); - if (PyType_Ready(&ScalarMapContainer_Type) < 0) { - return false; - } + if (PyType_Ready(&ScalarMapContainer_Type) < 0) { + return false; + } - PyModule_AddObject(m, "ScalarMapContainer", - reinterpret_cast<PyObject*>(&ScalarMapContainer_Type)); + PyModule_AddObject(m, "ScalarMapContainer", + reinterpret_cast<PyObject*>(&ScalarMapContainer_Type)); - if (PyType_Ready(&ScalarMapIterator_Type) < 0) { - return false; - } + if (PyType_Ready(&ScalarMapIterator_Type) < 0) { + return false; + } - PyModule_AddObject(m, "ScalarMapIterator", - reinterpret_cast<PyObject*>(&ScalarMapIterator_Type)); + PyModule_AddObject(m, "ScalarMapIterator", + reinterpret_cast<PyObject*>(&ScalarMapIterator_Type)); - Py_INCREF(mutable_mapping); - MessageMapContainer_Type.tp_base = - reinterpret_cast<PyTypeObject*>(mutable_mapping); + Py_INCREF(mutable_mapping); + MessageMapContainer_Type.tp_base = + reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); - if (PyType_Ready(&MessageMapContainer_Type) < 0) { - return false; - } + if (PyType_Ready(&MessageMapContainer_Type) < 0) { + return false; + } - PyModule_AddObject(m, "MessageMapContainer", - reinterpret_cast<PyObject*>(&MessageMapContainer_Type)); + PyModule_AddObject(m, "MessageMapContainer", + reinterpret_cast<PyObject*>(&MessageMapContainer_Type)); - if (PyType_Ready(&MessageMapIterator_Type) < 0) { - return false; - } + if (PyType_Ready(&MessageMapIterator_Type) < 0) { + return false; + } - PyModule_AddObject(m, "MessageMapIterator", - reinterpret_cast<PyObject*>(&MessageMapIterator_Type)); + PyModule_AddObject(m, "MessageMapIterator", + reinterpret_cast<PyObject*>(&MessageMapIterator_Type)); + } - ExtensionDict_Type.tp_hash = PyObject_HashNotImplemented; if (PyType_Ready(&ExtensionDict_Type) < 0) { return false; } - PyModule_AddObject( m, "ExtensionDict", reinterpret_cast<PyObject*>(&ExtensionDict_Type)); @@ -2751,6 +2942,7 @@ bool InitProto2MessageModule(PyObject *m) { } EncodeError_class = PyObject_GetAttrString(message_module, "EncodeError"); DecodeError_class = PyObject_GetAttrString(message_module, "DecodeError"); + PythonMessage_class = PyObject_GetAttrString(message_module, "Message"); Py_DECREF(message_module); PyObject* pickle_module = PyImport_ImportModule("pickle"); |