diff options
Diffstat (limited to 'python/google/protobuf/pyext/message.cc')
-rw-r--r-- | python/google/protobuf/pyext/message.cc | 1033 |
1 files changed, 516 insertions, 517 deletions
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 60ec9c1b..53736b9c 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -35,9 +35,6 @@ #include <map> #include <memory> -#ifndef _SHARED_PTR_H -#include <google/protobuf/stubs/shared_ptr.h> -#endif #include <string> #include <vector> #include <structmember.h> // A Python header file. @@ -52,6 +49,7 @@ #include <google/protobuf/stubs/common.h> #include <google/protobuf/stubs/logging.h> #include <google/protobuf/io/coded_stream.h> +#include <google/protobuf/io/zero_copy_stream_impl_lite.h> #include <google/protobuf/util/message_differencer.h> #include <google/protobuf/descriptor.h> #include <google/protobuf/message.h> @@ -63,11 +61,11 @@ #include <google/protobuf/pyext/repeated_composite_container.h> #include <google/protobuf/pyext/repeated_scalar_container.h> #include <google/protobuf/pyext/map_container.h> +#include <google/protobuf/pyext/message_factory.h> +#include <google/protobuf/pyext/safe_numerics.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> -#include <google/protobuf/stubs/strutil.h> #if PY_MAJOR_VERSION >= 3 - #define PyInt_Check PyLong_Check #define PyInt_AsLong PyLong_AsLong #define PyInt_FromLong PyLong_FromLong #define PyInt_FromSize_t PyLong_FromSize_t @@ -91,42 +89,26 @@ namespace protobuf { namespace python { static PyObject* kDESCRIPTOR; -static PyObject* k_extensions_by_name; -static PyObject* k_extensions_by_number; PyObject* EnumTypeWrapper_class; static PyObject* PythonMessage_class; static PyObject* kEmptyWeakref; static PyObject* WKT_classes = NULL; -// Defines the Metaclass of all Message classes. -// It allows us to cache some C++ pointers in the class object itself, they are -// faster to extract than from the type's dictionary. - -struct PyMessageMeta { - // This is how CPython subclasses C structures: the base structure must be - // the first member of the object. - PyHeapTypeObject super; - - // C++ descriptor of this message. - const Descriptor* message_descriptor; - - // Owned reference, used to keep the pointer above alive. - PyObject* py_message_descriptor; - - // The Python DescriptorPool used to create the class. It is needed to resolve - // fields descriptors, including extensions fields; its C++ MessageFactory is - // used to instantiate submessages. - // This can be different from DESCRIPTOR.file.pool, in the case of a custom - // DescriptorPool which defines new extensions. - // We own the reference, because it's important to keep the descriptors and - // factory alive. - PyDescriptorPool* py_descriptor_pool; -}; - namespace message_meta { static int InsertEmptyWeakref(PyTypeObject* base); +namespace { +// Copied oveer from internal 'google/protobuf/stubs/strutil.h'. +inline void UpperString(string * s) { + string::iterator end = s->end(); + for (string::iterator i = s->begin(); i != end; ++i) { + // toupper() changes based on locale. We don't want this! + if ('a' <= *i && *i <= 'z') *i += 'A' - 'a'; + } +} +} + // Add the number of a field descriptor to the containing message class. // Equivalent to: // _cls.<field>_FIELD_NUMBER = <number> @@ -152,19 +134,6 @@ static bool AddFieldNumberToClass( // Finalize the creation of the Message class. static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { - // If there are extension_ranges, the message is "extendable", and extension - // classes will register themselves in this class. - if (descriptor->extension_range_count() > 0) { - ScopedPyObjectPtr by_name(PyDict_New()); - if (PyObject_SetAttr(cls, k_extensions_by_name, by_name.get()) < 0) { - return -1; - } - ScopedPyObjectPtr by_number(PyDict_New()); - if (PyObject_SetAttr(cls, k_extensions_by_number, by_number.get()) < 0) { - return -1; - } - } - // For each field set: cls.<field>_FIELD_NUMBER = <number> for (int i = 0; i < descriptor->field_count(); ++i) { if (!AddFieldNumberToClass(cls, descriptor->field(i))) { @@ -173,10 +142,6 @@ static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { } // For each enum set cls.<enum name> = EnumTypeWrapper(<enum descriptor>). - // - // The enum descriptor we get from - // <messagedescriptor>.enum_types_by_name[name] - // which was built previously. for (int i = 0; i < descriptor->enum_type_count(); ++i) { const EnumDescriptor* enum_descriptor = descriptor->enum_type(i); ScopedPyObjectPtr enum_type( @@ -273,6 +238,12 @@ static PyObject* New(PyTypeObject* type, return NULL; } + // Messages have no __dict__ + ScopedPyObjectPtr slots(PyTuple_New(0)); + if (PyDict_SetItemString(dict, "__slots__", slots.get()) < 0) { + return NULL; + } + // Build the arguments to the base metaclass. // We change the __bases__ classes. ScopedPyObjectPtr new_args; @@ -309,7 +280,7 @@ static PyObject* New(PyTypeObject* type, if (result == NULL) { return NULL; } - PyMessageMeta* newtype = reinterpret_cast<PyMessageMeta*>(result.get()); + CMessageClass* newtype = reinterpret_cast<CMessageClass*>(result.get()); // Insert the empty weakref into the base classes. if (InsertEmptyWeakref( @@ -329,16 +300,19 @@ static PyObject* New(PyTypeObject* type, newtype->message_descriptor = descriptor; // TODO(amauryfa): Don't always use the canonical pool of the descriptor, // use the MessageFactory optionally passed in the class dict. - newtype->py_descriptor_pool = GetDescriptorPool_FromPool( - descriptor->file()->pool()); - if (newtype->py_descriptor_pool == NULL) { + PyDescriptorPool* py_descriptor_pool = + GetDescriptorPool_FromPool(descriptor->file()->pool()); + if (py_descriptor_pool == NULL) { return NULL; } - Py_INCREF(newtype->py_descriptor_pool); + newtype->py_message_factory = py_descriptor_pool->py_message_factory; + Py_INCREF(newtype->py_message_factory); - // Add the message to the DescriptorPool. - if (cdescriptor_pool::RegisterMessageClass(newtype->py_descriptor_pool, - descriptor, result.get()) < 0) { + // Register the message in the MessageFactory. + // TODO(amauryfa): Move this call to MessageFactory.GetPrototype() when the + // MessageFactory is fully implemented in C++. + if (message_factory::RegisterMessageClass(newtype->py_message_factory, + descriptor, newtype) < 0) { return NULL; } @@ -349,9 +323,9 @@ static PyObject* New(PyTypeObject* type, return result.release(); } -static void Dealloc(PyMessageMeta *self) { - Py_DECREF(self->py_message_descriptor); - Py_DECREF(self->py_descriptor_pool); +static void Dealloc(CMessageClass *self) { + Py_XDECREF(self->py_message_descriptor); + Py_XDECREF(self->py_message_factory); Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); } @@ -376,12 +350,67 @@ static int InsertEmptyWeakref(PyTypeObject *base_type) { #endif // PY_MAJOR_VERSION >= 3 } +// The _extensions_by_name dictionary is built on every access. +// TODO(amauryfa): Migrate all users to pool.FindAllExtensions() +static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) { + const PyDescriptorPool* pool = self->py_message_factory->pool; + + std::vector<const FieldDescriptor*> extensions; + pool->pool->FindAllExtensions(self->message_descriptor, &extensions); + + ScopedPyObjectPtr result(PyDict_New()); + for (int i = 0; i < extensions.size(); i++) { + ScopedPyObjectPtr extension( + PyFieldDescriptor_FromDescriptor(extensions[i])); + if (extension == NULL) { + return NULL; + } + if (PyDict_SetItemString(result.get(), extensions[i]->full_name().c_str(), + extension.get()) < 0) { + return NULL; + } + } + return result.release(); +} + +// The _extensions_by_number dictionary is built on every access. +// TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber() +static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) { + const PyDescriptorPool* pool = self->py_message_factory->pool; + + std::vector<const FieldDescriptor*> extensions; + pool->pool->FindAllExtensions(self->message_descriptor, &extensions); + + ScopedPyObjectPtr result(PyDict_New()); + for (int i = 0; i < extensions.size(); i++) { + ScopedPyObjectPtr extension( + PyFieldDescriptor_FromDescriptor(extensions[i])); + if (extension == NULL) { + return NULL; + } + ScopedPyObjectPtr number(PyInt_FromLong(extensions[i]->number())); + if (number == NULL) { + return NULL; + } + if (PyDict_SetItem(result.get(), number.get(), extension.get()) < 0) { + return NULL; + } + } + return result.release(); +} + +static PyGetSetDef Getters[] = { + {"_extensions_by_name", (getter)GetExtensionsByName, NULL}, + {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL}, + {NULL} +}; + } // namespace message_meta -PyTypeObject PyMessageMeta_Type = { +PyTypeObject CMessageClass_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME ".MessageMeta", // tp_name - sizeof(PyMessageMeta), // tp_basicsize + sizeof(CMessageClass), // tp_basicsize 0, // tp_itemsize (destructor)message_meta::Dealloc, // tp_dealloc 0, // tp_print @@ -408,7 +437,7 @@ PyTypeObject PyMessageMeta_Type = { 0, // tp_iternext 0, // tp_methods 0, // tp_members - 0, // tp_getset + message_meta::Getters, // tp_getset 0, // tp_base 0, // tp_dict 0, // tp_descr_get @@ -419,16 +448,16 @@ PyTypeObject PyMessageMeta_Type = { message_meta::New, // tp_new }; -static PyMessageMeta* CheckMessageClass(PyTypeObject* cls) { - if (!PyObject_TypeCheck(cls, &PyMessageMeta_Type)) { +static CMessageClass* CheckMessageClass(PyTypeObject* cls) { + if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) { PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name); return NULL; } - return reinterpret_cast<PyMessageMeta*>(cls); + return reinterpret_cast<CMessageClass*>(cls); } static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) { - PyMessageMeta* type = CheckMessageClass(cls); + CMessageClass* type = CheckMessageClass(cls); if (type == NULL) { return NULL; } @@ -544,23 +573,10 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) { // --------------------------------------------------------------------- -// Constants used for integer type range checking. -PyObject* kPythonZero; -PyObject* kint32min_py; -PyObject* kint32max_py; -PyObject* kuint32max_py; -PyObject* kint64min_py; -PyObject* kint64max_py; -PyObject* kuint64max_py; - PyObject* EncodeError_class; PyObject* DecodeError_class; PyObject* PickleError_class; -// Constant PyString values used for GetAttr/GetItem. -static PyObject* k_cdescriptor; -static PyObject* kfull_name; - /* Is 64bit */ void FormatTypeError(PyObject* arg, char* expected_types) { PyObject* repr = PyObject_Repr(arg); @@ -574,68 +590,126 @@ void FormatTypeError(PyObject* arg, char* expected_types) { } } -template<class T> -bool CheckAndGetInteger( - PyObject* arg, T* value, PyObject* min, PyObject* max) { - bool is_long = PyLong_Check(arg); -#if PY_MAJOR_VERSION < 3 - if (!PyInt_Check(arg) && !is_long) { - FormatTypeError(arg, "int, long"); - return false; +void OutOfRangeError(PyObject* arg) { + PyObject *s = PyObject_Str(arg); + if (s) { + PyErr_Format(PyExc_ValueError, + "Value out of range: %s", + PyString_AsString(s)); + Py_DECREF(s); } - if (PyObject_Compare(min, arg) > 0 || PyObject_Compare(max, arg) < 0) { -#else - if (!is_long) { - FormatTypeError(arg, "int"); +} + +template<class RangeType, class ValueType> +bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) { + if (GOOGLE_PREDICT_FALSE(value == -1 && PyErr_Occurred())) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + // Replace it with the same ValueError as pure python protos instead of + // the default one. + PyErr_Clear(); + OutOfRangeError(arg); + } // Otherwise propagate existing error. return false; - } - if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 || - PyObject_RichCompareBool(max, arg, Py_GE) != 1) { -#endif - if (!PyErr_Occurred()) { - PyObject *s = PyObject_Str(arg); - if (s) { - PyErr_Format(PyExc_ValueError, - "Value out of range: %s", - PyString_AsString(s)); - Py_DECREF(s); - } } - return false; - } + if (GOOGLE_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value))) { + OutOfRangeError(arg); + return false; + } + return true; +} + +template<class T> +bool CheckAndGetInteger(PyObject* arg, T* value) { + // The fast path. #if PY_MAJOR_VERSION < 3 - if (!is_long) { - *value = static_cast<T>(PyInt_AsLong(arg)); - } else // NOLINT + // For the typical case, offer a fast path. + if (GOOGLE_PREDICT_TRUE(PyInt_Check(arg))) { + long int_result = PyInt_AsLong(arg); + if (GOOGLE_PREDICT_TRUE(IsValidNumericCast<T>(int_result))) { + *value = static_cast<T>(int_result); + return true; + } else { + OutOfRangeError(arg); + return false; + } + } #endif - { - if (min == kPythonZero) { - *value = static_cast<T>(PyLong_AsUnsignedLongLong(arg)); + // This effectively defines an integer as "an object that can be cast as + // an integer and can be used as an ordinal number". + // This definition includes everything that implements numbers.Integral + // and shouldn't cast the net too wide. + if (GOOGLE_PREDICT_FALSE(!PyIndex_Check(arg))) { + FormatTypeError(arg, "int, long"); + return false; + } + + // Now we have an integral number so we can safely use PyLong_ functions. + // We need to treat the signed and unsigned cases differently in case arg is + // holding a value above the maximum for signed longs. + if (std::numeric_limits<T>::min() == 0) { + // Unsigned case. + unsigned PY_LONG_LONG ulong_result; + if (PyLong_Check(arg)) { + ulong_result = PyLong_AsUnsignedLongLong(arg); } else { - *value = static_cast<T>(PyLong_AsLongLong(arg)); + // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very + // picky about the exact type. + PyObject* casted = PyNumber_Long(arg); + if (GOOGLE_PREDICT_FALSE(casted == nullptr)) { + // Propagate existing error. + return false; + } + ulong_result = PyLong_AsUnsignedLongLong(casted); + Py_DECREF(casted); + } + if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg, + ulong_result)) { + *value = static_cast<T>(ulong_result); + } else { + return false; + } + } else { + // Signed case. + PY_LONG_LONG long_result; + PyNumberMethods *nb; + if ((nb = arg->ob_type->tp_as_number) != NULL && nb->nb_int != NULL) { + // PyLong_AsLongLong requires it to be a long or to have an __int__() + // method. + long_result = PyLong_AsLongLong(arg); + } else { + // Valid subclasses of numbers.Integral should have a __long__() method + // so fall back to that. + PyObject* casted = PyNumber_Long(arg); + if (GOOGLE_PREDICT_FALSE(casted == nullptr)) { + // Propagate existing error. + return false; + } + long_result = PyLong_AsLongLong(casted); + Py_DECREF(casted); + } + if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) { + *value = static_cast<T>(long_result); + } else { + return false; } } + return true; } // These are referenced by repeated_scalar_container, and must // be explicitly instantiated. -template bool CheckAndGetInteger<int32>( - PyObject*, int32*, PyObject*, PyObject*); -template bool CheckAndGetInteger<int64>( - PyObject*, int64*, PyObject*, PyObject*); -template bool CheckAndGetInteger<uint32>( - PyObject*, uint32*, PyObject*, PyObject*); -template bool CheckAndGetInteger<uint64>( - PyObject*, uint64*, PyObject*, PyObject*); +template bool CheckAndGetInteger<int32>(PyObject*, int32*); +template bool CheckAndGetInteger<int64>(PyObject*, int64*); +template bool CheckAndGetInteger<uint32>(PyObject*, uint32*); +template bool CheckAndGetInteger<uint64>(PyObject*, uint64*); bool CheckAndGetDouble(PyObject* arg, double* value) { - if (!PyInt_Check(arg) && !PyLong_Check(arg) && - !PyFloat_Check(arg)) { + *value = PyFloat_AsDouble(arg); + if (GOOGLE_PREDICT_FALSE(*value == -1 && PyErr_Occurred())) { FormatTypeError(arg, "int, long, float"); return false; - } - *value = PyFloat_AsDouble(arg); + } return true; } @@ -649,11 +723,13 @@ bool CheckAndGetFloat(PyObject* arg, float* value) { } bool CheckAndGetBool(PyObject* arg, bool* value) { - if (!PyInt_Check(arg) && !PyBool_Check(arg) && !PyLong_Check(arg)) { + long long_value = PyInt_AsLong(arg); + if (long_value == -1 && PyErr_Occurred()) { FormatTypeError(arg, "int, long, bool"); return false; } - *value = static_cast<bool>(PyInt_AsLong(arg)); + *value = static_cast<bool>(long_value); + return true; } @@ -711,7 +787,7 @@ PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) { encoded_string = arg; // Already encoded. Py_INCREF(encoded_string); } else { - encoded_string = PyUnicode_AsEncodedObject(arg, "utf-8", NULL); + encoded_string = PyUnicode_AsEncodedString(arg, "utf-8", NULL); } } else { // In this case field type is "bytes". @@ -751,7 +827,8 @@ bool CheckAndSetString( return true; } -PyObject* ToStringObject(const FieldDescriptor* descriptor, string value) { +PyObject* ToStringObject(const FieldDescriptor* descriptor, + const string& value) { if (descriptor->type() != FieldDescriptor::TYPE_STRING) { return PyBytes_FromStringAndSize(value.c_str(), value.length()); } @@ -781,15 +858,9 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, namespace cmessage { -PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message) { - // No need to check the type: the type of instances of CMessage is always - // an instance of PyMessageMeta. Let's prove it with a debug-only check. +PyMessageFactory* GetFactoryForMessage(CMessage* message) { GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type)); - return reinterpret_cast<PyMessageMeta*>(Py_TYPE(message))->py_descriptor_pool; -} - -MessageFactory* GetFactoryForMessage(CMessage* message) { - return GetDescriptorPoolForMessage(message)->message_factory; + return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory; } static int MaybeReleaseOverlappingOneofField( @@ -842,7 +913,8 @@ static Message* GetMutableMessage( return NULL; } return reflection->MutableMessage( - parent_message, parent_field, GetFactoryForMessage(parent)); + parent_message, parent_field, + GetFactoryForMessage(parent)->message_factory); } struct FixupMessageReference : public ChildVisitor { @@ -990,28 +1062,17 @@ int InternalDeleteRepeatedField( int min, max; length = reflection->FieldSize(*message, field_descriptor); - if (PyInt_Check(slice) || PyLong_Check(slice)) { - from = to = PyLong_AsLong(slice); - if (from < 0) { - from = to = length + from; - } - step = 1; - min = max = from; - - // Range check. - if (from < 0 || from >= length) { - PyErr_Format(PyExc_IndexError, "list assignment index out of range"); - return -1; - } - } else if (PySlice_Check(slice)) { + if (PySlice_Check(slice)) { from = to = step = slice_length = 0; - PySlice_GetIndicesEx( #if PY_MAJOR_VERSION < 3 + PySlice_GetIndicesEx( reinterpret_cast<PySliceObject*>(slice), + length, &from, &to, &step, &slice_length); #else + PySlice_GetIndicesEx( slice, -#endif length, &from, &to, &step, &slice_length); +#endif if (from < to) { min = from; max = to - 1; @@ -1020,8 +1081,23 @@ int InternalDeleteRepeatedField( max = from; } } else { - PyErr_SetString(PyExc_TypeError, "list indices must be integers"); - return -1; + from = to = PyLong_AsLong(slice); + if (from == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, "list indices must be integers"); + return -1; + } + + if (from < 0) { + from = to = length + from; + } + step = 1; + min = max = from; + + // Range check. + if (from < 0 || from >= length) { + PyErr_Format(PyExc_IndexError, "list assignment index out of range"); + return -1; + } } Py_ssize_t i = from; @@ -1070,7 +1146,12 @@ int InternalDeleteRepeatedField( } // Initializes fields of a message. Used in constructors. -int InitAttributes(CMessage* self, PyObject* kwargs) { +int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { + if (args != NULL && PyTuple_Size(args) != 0) { + PyErr_SetString(PyExc_TypeError, "No positional arguments allowed"); + return -1; + } + if (kwargs == NULL) { return 0; } @@ -1090,8 +1171,12 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { PyString_AsString(name)); return -1; } + if (value == Py_None) { + // field=None is the same as no field at all. + continue; + } if (descriptor->is_map()) { - ScopedPyObjectPtr map(GetAttr(self, name)); + ScopedPyObjectPtr map(GetAttr(reinterpret_cast<PyObject*>(self), name)); const FieldDescriptor* value_descriptor = descriptor->message_type()->FindFieldByName("value"); if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { @@ -1119,7 +1204,8 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { } } } else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { - ScopedPyObjectPtr container(GetAttr(self, name)); + ScopedPyObjectPtr container( + GetAttr(reinterpret_cast<PyObject*>(self), name)); if (container == NULL) { return -1; } @@ -1186,13 +1272,16 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { } } } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - ScopedPyObjectPtr message(GetAttr(self, name)); + ScopedPyObjectPtr message( + GetAttr(reinterpret_cast<PyObject*>(self), name)); if (message == NULL) { return -1; } CMessage* cmessage = reinterpret_cast<CMessage*>(message.get()); if (PyDict_Check(value)) { - if (InitAttributes(cmessage, value) < 0) { + // Make the message exist even if the dict is empty. + AssureWritable(cmessage); + if (InitAttributes(cmessage, NULL, value) < 0) { return -1; } } else { @@ -1209,8 +1298,8 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { return -1; } } - if (SetAttr(self, name, (new_val.get() == NULL) ? value : new_val.get()) < - 0) { + if (SetAttr(reinterpret_cast<PyObject*>(self), name, + (new_val.get() == NULL) ? value : new_val.get()) < 0) { return -1; } } @@ -1220,13 +1309,15 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { // Allocates an incomplete Python Message: the caller must fill self->message, // self->owner and eventually self->parent. -CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) { +CMessage* NewEmptyMessage(CMessageClass* type) { CMessage* self = reinterpret_cast<CMessage*>( - PyType_GenericAlloc(reinterpret_cast<PyTypeObject*>(type), 0)); + PyType_GenericAlloc(&type->super.ht_type, 0)); if (self == NULL) { return NULL; } + // Use "placement new" syntax to initialize the C++ object. + new (&self->owner) CMessage::OwnerRef(NULL); self->message = NULL; self->parent = NULL; self->parent_field_descriptor = NULL; @@ -1242,7 +1333,7 @@ CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) { // Creates a new C++ message and takes ownership. static PyObject* New(PyTypeObject* cls, PyObject* unused_args, PyObject* unused_kwargs) { - PyMessageMeta* type = CheckMessageClass(cls); + CMessageClass* type = CheckMessageClass(cls); if (type == NULL) { return NULL; } @@ -1251,15 +1342,14 @@ static PyObject* New(PyTypeObject* cls, if (message_descriptor == NULL) { return NULL; } - const Message* default_message = type->py_descriptor_pool->message_factory + const Message* default_message = type->py_message_factory->message_factory ->GetPrototype(message_descriptor); if (default_message == NULL) { PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str()); return NULL; } - CMessage* self = NewEmptyMessage(reinterpret_cast<PyObject*>(type), - message_descriptor); + CMessage* self = NewEmptyMessage(type); if (self == NULL) { return NULL; } @@ -1271,12 +1361,7 @@ static PyObject* New(PyTypeObject* cls, // The __init__ method of Message classes. // It initializes fields from keywords passed to the constructor. static int Init(CMessage* self, PyObject* args, PyObject* kwargs) { - if (PyTuple_Size(args) != 0) { - PyErr_SetString(PyExc_TypeError, "No positional arguments allowed"); - return -1; - } - - return InitAttributes(self, kwargs); + return InitAttributes(self, args, kwargs); } // --------------------------------------------------------------------- @@ -1318,6 +1403,9 @@ struct ClearWeakReferences : public ChildVisitor { }; static void Dealloc(CMessage* self) { + if (self->weakreflist) { + PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self)); + } // Null out all weak references from children to this message. GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences())); if (self->extensions) { @@ -1326,7 +1414,7 @@ static void Dealloc(CMessage* self) { Py_CLEAR(self->extensions); Py_CLEAR(self->composite_fields); - self->owner.reset(); + self->owner.~ThreadUnsafeSharedPtr<Message>(); Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); } @@ -1467,36 +1555,25 @@ PyObject* HasField(CMessage* self, PyObject* arg) { if (message->GetReflection()->HasField(*message, field_descriptor)) { Py_RETURN_TRUE; } - if (!message->GetReflection()->SupportsUnknownEnumValues() && - field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { - // Special case: Python HasField() differs in semantics from C++ - // slightly: we return HasField('enum_field') == true if there is - // an unknown enum value present. To implement this we have to - // look in the UnknownFieldSet. - const UnknownFieldSet& unknown_field_set = - message->GetReflection()->GetUnknownFields(*message); - for (int i = 0; i < unknown_field_set.field_count(); ++i) { - if (unknown_field_set.field(i).number() == field_descriptor->number()) { - Py_RETURN_TRUE; - } - } - } + Py_RETURN_FALSE; } PyObject* ClearExtension(CMessage* self, PyObject* extension) { + const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); + if (descriptor == NULL) { + return NULL; + } if (self->extensions != NULL) { - return extension_dict::ClearExtension(self->extensions, extension); - } else { - const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); - if (descriptor == NULL) { - return NULL; - } - if (ScopedPyObjectPtr(ClearFieldByDescriptor(self, descriptor)) == NULL) { - return NULL; + PyObject* value = PyDict_GetItem(self->extensions->values, extension); + if (value != NULL) { + if (InternalReleaseFieldByDescriptor(self, descriptor, value) < 0) { + return NULL; + } + PyDict_DelItem(self->extensions->values, extension); } } - Py_RETURN_NONE; + return ClearFieldByDescriptor(self, descriptor); } PyObject* HasExtension(CMessage* self, PyObject* extension) { @@ -1539,9 +1616,10 @@ PyObject* HasExtension(CMessage* self, PyObject* extension) { // * Clear the weak references from the released container to the // parent. -struct SetOwnerVisitor : public ChildVisitor { +class SetOwnerVisitor : public ChildVisitor { + public: // new_owner must outlive this object. - explicit SetOwnerVisitor(const shared_ptr<Message>& new_owner) + explicit SetOwnerVisitor(const CMessage::OwnerRef& new_owner) : new_owner_(new_owner) {} int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { @@ -1565,11 +1643,11 @@ struct SetOwnerVisitor : public ChildVisitor { } private: - const shared_ptr<Message>& new_owner_; + const CMessage::OwnerRef& new_owner_; }; // Change the owner of this CMessage and all its children, recursively. -int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) { +int SetOwner(CMessage* self, const CMessage::OwnerRef& new_owner) { self->owner = new_owner; if (ForEachCompositeField(self, SetOwnerVisitor(new_owner)) == -1) return -1; @@ -1582,7 +1660,7 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) { Message* ReleaseMessage(CMessage* self, const Descriptor* descriptor, const FieldDescriptor* field_descriptor) { - MessageFactory* message_factory = GetFactoryForMessage(self); + MessageFactory* message_factory = GetFactoryForMessage(self)->message_factory; Message* released_message = self->message->GetReflection()->ReleaseMessage( self->message, field_descriptor, message_factory); // ReleaseMessage will return NULL which differs from @@ -1602,7 +1680,7 @@ int ReleaseSubMessage(CMessage* self, const FieldDescriptor* field_descriptor, CMessage* child_cmessage) { // Release the Message - shared_ptr<Message> released_message(ReleaseMessage( + CMessage::OwnerRef released_message(ReleaseMessage( self, child_cmessage->message->GetDescriptor(), field_descriptor)); child_cmessage->message = released_message.get(); child_cmessage->owner.swap(released_message); @@ -1619,23 +1697,20 @@ struct ReleaseChild : public ChildVisitor { parent_(parent) {} int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { - return repeated_composite_container::Release( - reinterpret_cast<RepeatedCompositeContainer*>(container)); + return repeated_composite_container::Release(container); } int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) { - return repeated_scalar_container::Release( - reinterpret_cast<RepeatedScalarContainer*>(container)); + return repeated_scalar_container::Release(container); } int VisitMapContainer(MapContainer* container) { - return reinterpret_cast<MapContainer*>(container)->Release(); + return container->Release(); } int VisitCMessage(CMessage* cmessage, const FieldDescriptor* field_descriptor) { - return ReleaseSubMessage(parent_, field_descriptor, - reinterpret_cast<CMessage*>(cmessage)); + return ReleaseSubMessage(parent_, field_descriptor, cmessage); } CMessage* parent_; @@ -1653,12 +1728,13 @@ int InternalReleaseFieldByDescriptor( PyObject* ClearFieldByDescriptor( CMessage* self, - const FieldDescriptor* descriptor) { - if (!CheckFieldBelongsToMessage(descriptor, self->message)) { + const FieldDescriptor* field_descriptor) { + if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) { return NULL; } AssureWritable(self); - self->message->GetReflection()->ClearField(self->message, descriptor); + Message* message = self->message; + message->GetReflection()->ClearField(message, field_descriptor); Py_RETURN_NONE; } @@ -1694,27 +1770,17 @@ PyObject* ClearField(CMessage* self, PyObject* arg) { arg = arg_in_oneof.get(); } - PyObject* composite_field = self->composite_fields ? - PyDict_GetItem(self->composite_fields, arg) : NULL; - - // Only release the field if there's a possibility that there are - // references to it. - if (composite_field != NULL) { - if (InternalReleaseFieldByDescriptor(self, field_descriptor, - composite_field) < 0) { - return NULL; + // Release the field if it exists in the dict of composite fields. + if (self->composite_fields) { + PyObject* value = PyDict_GetItem(self->composite_fields, arg); + if (value != NULL) { + if (InternalReleaseFieldByDescriptor(self, field_descriptor, value) < 0) { + return NULL; + } + PyDict_DelItem(self->composite_fields, arg); } - PyDict_DelItem(self->composite_fields, arg); - } - message->GetReflection()->ClearField(message, field_descriptor); - if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM && - !message->GetReflection()->SupportsUnknownEnumValues()) { - UnknownFieldSet* unknown_field_set = - message->GetReflection()->MutableUnknownFields(message); - unknown_field_set->DeleteByNumber(field_descriptor->number()); } - - Py_RETURN_NONE; + return ClearFieldByDescriptor(self, field_descriptor); } PyObject* Clear(CMessage* self) { @@ -1739,8 +1805,25 @@ static string GetMessageName(CMessage* self) { } } -static PyObject* SerializeToString(CMessage* self, PyObject* args) { - if (!self->message->IsInitialized()) { +static PyObject* InternalSerializeToString( + CMessage* self, PyObject* args, PyObject* kwargs, + bool require_initialized) { + // Parse the "deterministic" kwarg; defaults to False. + static char* kwlist[] = { "deterministic", 0 }; + PyObject* deterministic_obj = Py_None; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, + &deterministic_obj)) { + return NULL; + } + // Preemptively convert to a bool first, so we don't need to back out of + // allocating memory if this raises an exception. + // NOTE: This is unused later if deterministic == Py_None, but that's fine. + int deterministic = PyObject_IsTrue(deterministic_obj); + if (deterministic < 0) { + return NULL; + } + + if (require_initialized && !self->message->IsInitialized()) { ScopedPyObjectPtr errors(FindInitializationErrors(self)); if (errors == NULL) { return NULL; @@ -1778,24 +1861,36 @@ static PyObject* SerializeToString(CMessage* self, PyObject* args) { GetMessageName(self).c_str(), PyString_AsString(joined.get())); return NULL; } - int size = self->message->ByteSize(); - if (size <= 0) { + + // Ok, arguments parsed and errors checked, now encode to a string + const size_t size = self->message->ByteSizeLong(); + if (size == 0) { return PyBytes_FromString(""); } PyObject* result = PyBytes_FromStringAndSize(NULL, size); if (result == NULL) { return NULL; } - char* buffer = PyBytes_AS_STRING(result); - self->message->SerializeWithCachedSizesToArray( - reinterpret_cast<uint8*>(buffer)); + io::ArrayOutputStream out(PyBytes_AS_STRING(result), size); + io::CodedOutputStream coded_out(&out); + if (deterministic_obj != Py_None) { + coded_out.SetSerializationDeterministic(deterministic); + } + self->message->SerializeWithCachedSizes(&coded_out); + GOOGLE_CHECK(!coded_out.HadError()); return result; } -static PyObject* SerializePartialToString(CMessage* self) { - string contents; - self->message->SerializePartialToString(&contents); - return PyBytes_FromStringAndSize(contents.c_str(), contents.size()); +static PyObject* SerializeToString( + CMessage* self, PyObject* args, PyObject* kwargs) { + return InternalSerializeToString(self, args, kwargs, + /*require_initialized=*/true); +} + +static PyObject* SerializePartialToString( + CMessage* self, PyObject* args, PyObject* kwargs) { + return InternalSerializeToString(self, args, kwargs, + /*require_initialized=*/false); } // Formats proto fields for ascii dumps using python formatting functions where @@ -1851,8 +1946,12 @@ static PyObject* ToStr(CMessage* self) { PyObject* MergeFrom(CMessage* self, PyObject* arg) { CMessage* other_message; - if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) { - PyErr_SetString(PyExc_TypeError, "Must be a message"); + if (!PyObject_TypeCheck(arg, &CMessage_Type)) { + PyErr_Format(PyExc_TypeError, + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s.", + self->message->GetDescriptor()->full_name().c_str(), + Py_TYPE(arg)->tp_name); return NULL; } @@ -1860,8 +1959,8 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) { if (other_message->message->GetDescriptor() != self->message->GetDescriptor()) { PyErr_Format(PyExc_TypeError, - "Tried to merge from a message with a different type. " - "to: %s, from: %s", + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s.", self->message->GetDescriptor()->full_name().c_str(), other_message->message->GetDescriptor()->full_name().c_str()); return NULL; @@ -1879,8 +1978,12 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) { static PyObject* CopyFrom(CMessage* self, PyObject* arg) { CMessage* other_message; - if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) { - PyErr_SetString(PyExc_TypeError, "Must be a message"); + if (!PyObject_TypeCheck(arg, &CMessage_Type)) { + PyErr_Format(PyExc_TypeError, + "Parameter to CopyFrom() must be instance of same class: " + "expected %s got %s.", + self->message->GetDescriptor()->full_name().c_str(), + Py_TYPE(arg)->tp_name); return NULL; } @@ -1893,8 +1996,8 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) { if (other_message->message->GetDescriptor() != self->message->GetDescriptor()) { PyErr_Format(PyExc_TypeError, - "Tried to copy from a message with a different type. " - "to: %s, from: %s", + "Parameter to CopyFrom() must be instance of same class: " + "expected %s got %s.", self->message->GetDescriptor()->full_name().c_str(), other_message->message->GetDescriptor()->full_name().c_str()); return NULL; @@ -1911,6 +2014,34 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) { Py_RETURN_NONE; } +// Protobuf has a 64MB limit built in, this variable will override this. Please +// do not enable this unless you fully understand the implications: protobufs +// must all be kept in memory at the same time, so if they grow too big you may +// get OOM errors. The protobuf APIs do not provide any tools for processing +// protobufs in chunks. If you have protos this big you should break them up if +// it is at all convenient to do so. +#ifdef PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS +static bool allow_oversize_protos = true; +#else +static bool allow_oversize_protos = false; +#endif + +// Provide a method in the module to set allow_oversize_protos to a boolean +// value. This method returns the newly value of allow_oversize_protos. +PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) { + if (!arg || !PyBool_Check(arg)) { + PyErr_SetString(PyExc_TypeError, + "Argument to SetAllowOversizeProtos must be boolean"); + return NULL; + } + allow_oversize_protos = PyObject_IsTrue(arg); + if (allow_oversize_protos) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + static PyObject* MergeFromString(CMessage* self, PyObject* arg) { const void* data; Py_ssize_t data_length; @@ -1921,19 +2052,18 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { AssureWritable(self); io::CodedInputStream input( reinterpret_cast<const uint8*>(data), data_length); -#if PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS - // Protobuf has a 64MB limit built in, this code will override this. Please do - // not enable this unless you fully understand the implications: protobufs - // must all be kept in memory at the same time, so if they grow too big you - // may get OOM errors. The protobuf APIs do not provide any tools for - // processing protobufs in chunks. If you have protos this big you should - // break them up if it is at all convenient to do so. - input.SetTotalBytesLimit(INT_MAX, INT_MAX); -#endif // PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS - PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); - input.SetExtensionRegistry(pool->pool, pool->message_factory); + if (allow_oversize_protos) { + input.SetTotalBytesLimit(INT_MAX, INT_MAX); + } + PyMessageFactory* factory = GetFactoryForMessage(self); + input.SetExtensionRegistry(factory->pool->pool, factory->message_factory); bool success = self->message->MergePartialFromCodedStream(&input); if (success) { + if (!input.ConsumedEntireMessage()) { + // TODO(jieluo): Raise error and return NULL instead. + // b/27494216 + PyErr_Warn(NULL, "Unexpected end-group tag: Not all data was converted"); + } return PyInt_FromLong(input.CurrentPosition()); } else { PyErr_Format(DecodeError_class, "Error parsing message"); @@ -1952,75 +2082,29 @@ static PyObject* ByteSize(CMessage* self, PyObject* args) { return PyLong_FromLong(self->message->ByteSize()); } -static PyObject* RegisterExtension(PyObject* cls, - PyObject* extension_handle) { +PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) { const FieldDescriptor* descriptor = GetExtensionDescriptor(extension_handle); if (descriptor == NULL) { return NULL; } - - ScopedPyObjectPtr extensions_by_name( - PyObject_GetAttr(cls, k_extensions_by_name)); - if (extensions_by_name == NULL) { - PyErr_SetString(PyExc_TypeError, "no extensions_by_name on class"); + if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) { + PyErr_Format(PyExc_TypeError, "Expected a message class, got %s", + cls->ob_type->tp_name); return NULL; } - ScopedPyObjectPtr full_name(PyObject_GetAttr(extension_handle, kfull_name)); - if (full_name == NULL) { + CMessageClass *message_class = reinterpret_cast<CMessageClass*>(cls); + if (message_class == NULL) { return NULL; } - // If the extension was already registered, check that it is the same. - PyObject* existing_extension = - PyDict_GetItem(extensions_by_name.get(), full_name.get()); - if (existing_extension != NULL) { - const FieldDescriptor* existing_extension_descriptor = - GetExtensionDescriptor(existing_extension); - if (existing_extension_descriptor != descriptor) { - PyErr_SetString(PyExc_ValueError, "Double registration of Extensions"); - return NULL; - } - // Nothing else to do. - Py_RETURN_NONE; - } - - if (PyDict_SetItem(extensions_by_name.get(), full_name.get(), - extension_handle) < 0) { - return NULL; - } - - // Also store a mapping from extension number to implementing class. - ScopedPyObjectPtr extensions_by_number( - PyObject_GetAttr(cls, k_extensions_by_number)); - if (extensions_by_number == NULL) { - PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class"); - return NULL; - } - ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number")); - if (number == NULL) { - return NULL; - } - if (PyDict_SetItem(extensions_by_number.get(), number.get(), - extension_handle) < 0) { + const FieldDescriptor* existing_extension = + message_class->py_message_factory->pool->pool->FindExtensionByNumber( + descriptor->containing_type(), descriptor->number()); + if (existing_extension != NULL && existing_extension != descriptor) { + PyErr_SetString(PyExc_ValueError, "Double registration of Extensions"); return NULL; } - - // Check if it's a message set - if (descriptor->is_extension() && - descriptor->containing_type()->options().message_set_wire_format() && - descriptor->type() == FieldDescriptor::TYPE_MESSAGE && - descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) { - ScopedPyObjectPtr message_name(PyString_FromStringAndSize( - descriptor->message_type()->full_name().c_str(), - descriptor->message_type()->full_name().size())); - if (message_name == NULL) { - return NULL; - } - PyDict_SetItem(extensions_by_name.get(), message_name.get(), - extension_handle); - } - Py_RETURN_NONE; } @@ -2057,7 +2141,7 @@ static PyObject* WhichOneof(CMessage* self, PyObject* arg) { static PyObject* GetExtensionDict(CMessage* self, void *closure); static PyObject* ListFields(CMessage* self) { - vector<const FieldDescriptor*> fields; + std::vector<const FieldDescriptor*> fields; self->message->GetReflection()->ListFields(*self->message, &fields); // Normally, the list will be exactly the size of the fields. @@ -2087,8 +2171,8 @@ static PyObject* ListFields(CMessage* self) { // is no message class and we cannot retrieve the value. // TODO(amauryfa): consider building the class on the fly! if (fields[i]->message_type() != NULL && - cdescriptor_pool::GetMessageClass( - GetDescriptorPoolForMessage(self), + message_factory::GetMessageClass( + GetFactoryForMessage(self), fields[i]->message_type()) == NULL) { PyErr_Clear(); continue; @@ -2121,7 +2205,8 @@ static PyObject* ListFields(CMessage* self) { return NULL; } - PyObject* field_value = GetAttr(self, py_field_name.get()); + PyObject* field_value = + GetAttr(reinterpret_cast<PyObject*>(self), py_field_name.get()); if (field_value == NULL) { PyErr_SetObject(PyExc_ValueError, py_field_name.get()); return NULL; @@ -2132,13 +2217,23 @@ static PyObject* ListFields(CMessage* self) { PyList_SET_ITEM(all_fields.get(), actual_size, t.release()); ++actual_size; } - Py_SIZE(all_fields.get()) = actual_size; + if (static_cast<size_t>(actual_size) != fields.size() && + (PyList_SetSlice(all_fields.get(), actual_size, fields.size(), NULL) < + 0)) { + return NULL; + } return all_fields.release(); } +static PyObject* DiscardUnknownFields(CMessage* self) { + AssureWritable(self); + self->message->DiscardUnknownFields(); + Py_RETURN_NONE; +} + PyObject* FindInitializationErrors(CMessage* self) { Message* message = self->message; - vector<string> errors; + std::vector<string> errors; message->FindInitializationErrors(&errors); PyObject* error_list = PyList_New(errors.size()); @@ -2235,32 +2330,16 @@ PyObject* InternalGetScalar(const Message* message, break; } case FieldDescriptor::CPPTYPE_STRING: { - string value = reflection->GetString(*message, field_descriptor); + string scratch; + const string& value = + reflection->GetStringReference(*message, field_descriptor, &scratch); result = ToStringObject(field_descriptor, value); break; } case FieldDescriptor::CPPTYPE_ENUM: { - if (!message->GetReflection()->SupportsUnknownEnumValues() && - !message->GetReflection()->HasField(*message, field_descriptor)) { - // Look for the value in the unknown fields. - const UnknownFieldSet& unknown_field_set = - message->GetReflection()->GetUnknownFields(*message); - for (int i = 0; i < unknown_field_set.field_count(); ++i) { - if (unknown_field_set.field(i).number() == - field_descriptor->number() && - unknown_field_set.field(i).type() == - google::protobuf::UnknownField::TYPE_VARINT) { - result = PyInt_FromLong(unknown_field_set.field(i).varint()); - break; - } - } - } - - if (result == NULL) { - const EnumValueDescriptor* enum_value = - message->GetReflection()->GetEnum(*message, field_descriptor); - result = PyInt_FromLong(enum_value->number()); - } + const EnumValueDescriptor* enum_value = + message->GetReflection()->GetEnum(*message, field_descriptor); + result = PyInt_FromLong(enum_value->number()); break; } default: @@ -2275,18 +2354,19 @@ PyObject* InternalGetScalar(const Message* message, PyObject* InternalGetSubMessage( CMessage* self, const FieldDescriptor* field_descriptor) { const Reflection* reflection = self->message->GetReflection(); - PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); + PyMessageFactory* factory = GetFactoryForMessage(self); const Message& sub_message = reflection->GetMessage( - *self->message, field_descriptor, pool->message_factory); + *self->message, field_descriptor, factory->message_factory); - PyObject *message_class = cdescriptor_pool::GetMessageClass( - pool, field_descriptor->message_type()); + CMessageClass* message_class = message_factory::GetOrCreateMessageClass( + factory, field_descriptor->message_type()); + ScopedPyObjectPtr message_class_handler( + reinterpret_cast<PyObject*>(message_class)); if (message_class == NULL) { return NULL; } - CMessage* cmsg = cmessage::NewEmptyMessage(message_class, - sub_message.GetDescriptor()); + CMessage* cmsg = cmessage::NewEmptyMessage(message_class); if (cmsg == NULL) { return NULL; } @@ -2471,7 +2551,10 @@ PyObject* Reduce(CMessage* self) { if (state == NULL) { return NULL; } - ScopedPyObjectPtr serialized(SerializePartialToString(self)); + string contents; + self->message->SerializePartialToString(&contents); + ScopedPyObjectPtr serialized( + PyBytes_FromStringAndSize(contents.c_str(), contents.size())); if (serialized == NULL) { return NULL; } @@ -2531,11 +2614,24 @@ static PyObject* GetExtensionDict(CMessage* self, void *closure) { return NULL; } +static PyObject* GetExtensionsByName(CMessage *self, void *closure) { + return message_meta::GetExtensionsByName( + reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure); +} + +static PyObject* GetExtensionsByNumber(CMessage *self, void *closure) { + return message_meta::GetExtensionsByNumber( + reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure); +} + static PyGetSetDef Getters[] = { {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"}, + {"_extensions_by_name", (getter)GetExtensionsByName, NULL}, + {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL}, {NULL} }; + static PyMethodDef Methods[] = { { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, "Makes a deep copy of the class." }, @@ -2555,6 +2651,8 @@ static PyMethodDef Methods[] = { "Clears a message field." }, { "CopyFrom", (PyCFunction)CopyFrom, METH_O, "Copies a protocol message into the current message." }, + { "DiscardUnknownFields", (PyCFunction)DiscardUnknownFields, METH_NOARGS, + "Discards the unknown fields." }, { "FindInitializationErrors", (PyCFunction)FindInitializationErrors, METH_NOARGS, "Finds unset required fields." }, @@ -2577,9 +2675,10 @@ static PyMethodDef Methods[] = { { "RegisterExtension", (PyCFunction)RegisterExtension, METH_O | METH_CLASS, "Registers an extension with the current message." }, { "SerializePartialToString", (PyCFunction)SerializePartialToString, - METH_NOARGS, + METH_VARARGS | METH_KEYWORDS, "Serializes the message to a string, even if it isn't initialized." }, - { "SerializeToString", (PyCFunction)SerializeToString, METH_NOARGS, + { "SerializeToString", (PyCFunction)SerializeToString, + METH_VARARGS | METH_KEYWORDS, "Serializes the message to a string, only for initialized messages." }, { "SetInParent", (PyCFunction)SetInParent, METH_NOARGS, "Sets the has bit of the given field in its parent message." }, @@ -2605,7 +2704,8 @@ static bool SetCompositeField( return PyDict_SetItem(self->composite_fields, name, value) == 0; } -PyObject* GetAttr(CMessage* self, PyObject* name) { +PyObject* GetAttr(PyObject* pself, PyObject* name) { + CMessage* self = reinterpret_cast<CMessage*>(pself); PyObject* value = self->composite_fields ? PyDict_GetItem(self->composite_fields, name) : NULL; if (value != NULL) { @@ -2624,8 +2724,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { const Descriptor* entry_type = field_descriptor->message_type(); const FieldDescriptor* value_type = entry_type->FindFieldByName("value"); if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - PyObject* value_class = cdescriptor_pool::GetMessageClass( - GetDescriptorPoolForMessage(self), value_type->message_type()); + CMessageClass* value_class = message_factory::GetMessageClass( + GetFactoryForMessage(self), value_type->message_type()); if (value_class == NULL) { return NULL; } @@ -2647,8 +2747,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { PyObject* py_container = NULL; if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - PyObject *message_class = cdescriptor_pool::GetMessageClass( - GetDescriptorPoolForMessage(self), field_descriptor->message_type()); + CMessageClass* message_class = message_factory::GetMessageClass( + GetFactoryForMessage(self), field_descriptor->message_type()); if (message_class == NULL) { return NULL; } @@ -2683,7 +2783,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { return InternalGetScalar(self->message, field_descriptor); } -int SetAttr(CMessage* self, PyObject* name, PyObject* value) { +int SetAttr(PyObject* pself, PyObject* name, PyObject* value) { + CMessage* self = reinterpret_cast<CMessage*>(pself); if (self->composite_fields && PyDict_Contains(self->composite_fields, name)) { PyErr_SetString(PyExc_TypeError, "Can't set composite field"); return -1; @@ -2711,7 +2812,7 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) { PyErr_Format(PyExc_AttributeError, "Assignment not allowed " - "(no field \"%s\"in protocol message object).", + "(no field \"%s\" in protocol message object).", PyString_AsString(name)); return -1; } @@ -2719,7 +2820,7 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) { } // namespace cmessage PyTypeObject CMessage_Type = { - PyVarObject_HEAD_INIT(&PyMessageMeta_Type, 0) + PyVarObject_HEAD_INIT(&CMessageClass_Type, 0) FULL_MODULE_NAME ".CMessage", // tp_name sizeof(CMessage), // tp_basicsize 0, // tp_itemsize @@ -2728,22 +2829,22 @@ PyTypeObject CMessage_Type = { 0, // tp_getattr 0, // tp_setattr 0, // tp_compare - 0, // tp_repr + (reprfunc)cmessage::ToStr, // tp_repr 0, // tp_as_number 0, // tp_as_sequence 0, // tp_as_mapping PyObject_HashNotImplemented, // tp_hash 0, // tp_call (reprfunc)cmessage::ToStr, // tp_str - (getattrofunc)cmessage::GetAttr, // tp_getattro - (setattrofunc)cmessage::SetAttr, // tp_setattro + cmessage::GetAttr, // tp_getattro + cmessage::SetAttr, // tp_setattro 0, // tp_as_buffer Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags "A ProtocolMessage", // tp_doc 0, // tp_traverse 0, // tp_clear (richcmpfunc)cmessage::RichCompare, // tp_richcompare - 0, // tp_weaklistoffset + offsetof(CMessage, weakreflist), // tp_weaklistoffset 0, // tp_iter 0, // tp_iternext cmessage::Methods, // tp_methods @@ -2765,17 +2866,38 @@ const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg); Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg); static const Message* GetCProtoInsidePyProtoImpl(PyObject* msg) { + const Message* message = PyMessage_GetMessagePointer(msg); + if (message == NULL) { + PyErr_Clear(); + return NULL; + } + return message; +} + +static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) { + Message* message = PyMessage_GetMutableMessagePointer(msg); + if (message == NULL) { + PyErr_Clear(); + return NULL; + } + return message; +} + +const Message* PyMessage_GetMessagePointer(PyObject* msg) { if (!PyObject_TypeCheck(msg, &CMessage_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a Message instance"); return NULL; } CMessage* cmsg = reinterpret_cast<CMessage*>(msg); return cmsg->message; } -static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) { +Message* PyMessage_GetMutableMessagePointer(PyObject* msg) { if (!PyObject_TypeCheck(msg, &CMessage_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a Message instance"); return NULL; } + CMessage* cmsg = reinterpret_cast<CMessage*>(msg); if ((cmsg->composite_fields && PyDict_Size(cmsg->composite_fields) != 0) || (cmsg->extensions != NULL && @@ -2784,36 +2906,20 @@ static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) { // the underlying C++ message back to the CMessage (e.g. removed repeated // composite containers). We only allow direct mutation of the underlying // C++ message if there is no child data in the CMessage. + PyErr_SetString(PyExc_ValueError, + "Cannot reliably get a mutable pointer " + "to a message with extra references"); return NULL; } cmessage::AssureWritable(cmsg); return cmsg->message; } -static const char module_docstring[] = -"python-proto2 is a module that can be used to enhance proto2 Python API\n" -"performance.\n" -"\n" -"It provides access to the protocol buffers C++ reflection API that\n" -"implements the basic protocol buffer functions."; - void InitGlobals() { // TODO(gps): Check all return values in this function for NULL and propagate // the error (MemoryError) on up to result in an import failure. These should // also be freed and reset to NULL during finalization. - kPythonZero = PyInt_FromLong(0); - kint32min_py = PyInt_FromLong(kint32min); - kint32max_py = PyInt_FromLong(kint32max); - kuint32max_py = PyLong_FromLongLong(kuint32max); - kint64min_py = PyLong_FromLongLong(kint64min); - kint64max_py = PyLong_FromLongLong(kint64max); - kuint64max_py = PyLong_FromUnsignedLongLong(kuint64max); - kDESCRIPTOR = PyString_FromString("DESCRIPTOR"); - k_cdescriptor = PyString_FromString("_cdescriptor"); - kfull_name = PyString_FromString("full_name"); - k_extensions_by_name = PyString_FromString("_extensions_by_name"); - k_extensions_by_number = PyString_FromString("_extensions_by_number"); PyObject *dummy_obj = PySet_New(NULL); kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL); @@ -2831,15 +2937,20 @@ bool InitProto2MessageModule(PyObject *m) { return false; } + // Initialize types and globals in message_factory.cc + if (!InitMessageFactory()) { + return false; + } + // Initialize constants defined in this file. InitGlobals(); - PyMessageMeta_Type.tp_base = &PyType_Type; - if (PyType_Ready(&PyMessageMeta_Type) < 0) { + CMessageClass_Type.tp_base = &PyType_Type; + if (PyType_Ready(&CMessageClass_Type) < 0) { return false; } PyModule_AddObject(m, "MessageMeta", - reinterpret_cast<PyObject*>(&PyMessageMeta_Type)); + reinterpret_cast<PyObject*>(&CMessageClass_Type)); if (PyType_Ready(&CMessage_Type) < 0) { return false; @@ -2848,25 +2959,6 @@ bool InitProto2MessageModule(PyObject *m) { // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set // it here as well to document that subclasses need to set it. PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None); - // Subclasses with message extensions will override _extensions_by_name and - // _extensions_by_number with fresh mutable dictionaries in AddDescriptors. - // All other classes can share this same immutable mapping. - ScopedPyObjectPtr empty_dict(PyDict_New()); - if (empty_dict == NULL) { - return false; - } - ScopedPyObjectPtr immutable_dict(PyDictProxy_New(empty_dict.get())); - if (immutable_dict == NULL) { - return false; - } - if (PyDict_SetItem(CMessage_Type.tp_dict, - k_extensions_by_name, immutable_dict.get()) < 0) { - return false; - } - if (PyDict_SetItem(CMessage_Type.tp_dict, - k_extensions_by_number, immutable_dict.get()) < 0) { - return false; - } PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(&CMessage_Type)); @@ -2912,69 +3004,15 @@ bool InitProto2MessageModule(PyObject *m) { } // Initialize Map container types. - { - // ScalarMapContainer_Type derives from our MutableMapping type. - ScopedPyObjectPtr containers(PyImport_ImportModule( - "google.protobuf.internal.containers")); - if (containers == NULL) { - return false; - } - - ScopedPyObjectPtr mutable_mapping( - PyObject_GetAttrString(containers.get(), "MutableMapping")); - if (mutable_mapping == NULL) { - return false; - } - - if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) { - return false; - } - - Py_INCREF(mutable_mapping.get()); -#if PY_MAJOR_VERSION >= 3 - PyObject* bases = PyTuple_New(1); - PyTuple_SET_ITEM(bases, 0, mutable_mapping.get()); - - ScalarMapContainer_Type = - PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases); - PyModule_AddObject(m, "ScalarMapContainer", ScalarMapContainer_Type); -#else - ScalarMapContainer_Type.tp_base = - reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); - - if (PyType_Ready(&ScalarMapContainer_Type) < 0) { - return false; - } - - PyModule_AddObject(m, "ScalarMapContainer", - reinterpret_cast<PyObject*>(&ScalarMapContainer_Type)); -#endif - - if (PyType_Ready(&MapIterator_Type) < 0) { - return false; - } - - PyModule_AddObject(m, "MapIterator", - reinterpret_cast<PyObject*>(&MapIterator_Type)); - - -#if PY_MAJOR_VERSION >= 3 - MessageMapContainer_Type = - PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases); - PyModule_AddObject(m, "MessageMapContainer", MessageMapContainer_Type); -#else - Py_INCREF(mutable_mapping.get()); - MessageMapContainer_Type.tp_base = - reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); - - if (PyType_Ready(&MessageMapContainer_Type) < 0) { - return false; - } - - PyModule_AddObject(m, "MessageMapContainer", - reinterpret_cast<PyObject*>(&MessageMapContainer_Type)); -#endif + if (!InitMapContainers()) { + return false; } + PyModule_AddObject(m, "ScalarMapContainer", + reinterpret_cast<PyObject*>(ScalarMapContainer_Type)); + PyModule_AddObject(m, "MessageMapContainer", + reinterpret_cast<PyObject*>(MessageMapContainer_Type)); + PyModule_AddObject(m, "MapIterator", + reinterpret_cast<PyObject*>(&MapIterator_Type)); if (PyType_Ready(&ExtensionDict_Type) < 0) { return false; @@ -3009,6 +3047,10 @@ bool InitProto2MessageModule(PyObject *m) { &PyFileDescriptor_Type)); PyModule_AddObject(m, "OneofDescriptor", reinterpret_cast<PyObject*>( &PyOneofDescriptor_Type)); + PyModule_AddObject(m, "ServiceDescriptor", reinterpret_cast<PyObject*>( + &PyServiceDescriptor_Type)); + PyModule_AddObject(m, "MethodDescriptor", reinterpret_cast<PyObject*>( + &PyMethodDescriptor_Type)); PyObject* enum_type_wrapper = PyImport_ImportModule( "google.protobuf.internal.enum_type_wrapper"); @@ -3045,47 +3087,4 @@ bool InitProto2MessageModule(PyObject *m) { } // namespace python } // namespace protobuf - - -#if PY_MAJOR_VERSION >= 3 -static struct PyModuleDef _module = { - PyModuleDef_HEAD_INIT, - "_message", - google::protobuf::python::module_docstring, - -1, - NULL, - NULL, - NULL, - NULL, - NULL -}; -#define INITFUNC PyInit__message -#define INITFUNC_ERRORVAL NULL -#else // Python 2 -#define INITFUNC init_message -#define INITFUNC_ERRORVAL -#endif - -extern "C" { - PyMODINIT_FUNC INITFUNC(void) { - PyObject* m; -#if PY_MAJOR_VERSION >= 3 - m = PyModule_Create(&_module); -#else - m = Py_InitModule3("_message", NULL, google::protobuf::python::module_docstring); -#endif - if (m == NULL) { - return INITFUNC_ERRORVAL; - } - - if (!google::protobuf::python::InitProto2MessageModule(m)) { - Py_DECREF(m); - return INITFUNC_ERRORVAL; - } - -#if PY_MAJOR_VERSION >= 3 - return m; -#endif - } -} } // namespace google |