diff options
Diffstat (limited to 'python/google/protobuf/pyext/map_container.cc')
-rw-r--r-- | python/google/protobuf/pyext/map_container.cc | 166 |
1 files changed, 138 insertions, 28 deletions
diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc index df9138a4..6d7ee285 100644 --- a/python/google/protobuf/pyext/map_container.cc +++ b/python/google/protobuf/pyext/map_container.cc @@ -32,13 +32,16 @@ #include <google/protobuf/pyext/map_container.h> +#include <memory> + #include <google/protobuf/stubs/logging.h> #include <google/protobuf/stubs/common.h> -#include <google/protobuf/stubs/scoped_ptr.h> #include <google/protobuf/map_field.h> #include <google/protobuf/map.h> #include <google/protobuf/message.h> +#include <google/protobuf/pyext/message_factory.h> #include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/repeated_composite_container.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #if PY_MAJOR_VERSION >= 3 @@ -70,7 +73,7 @@ class MapReflectionFriend { struct MapIterator { PyObject_HEAD; - scoped_ptr< ::google::protobuf::MapIterator> iter; + std::unique_ptr<::google::protobuf::MapIterator> iter; // A pointer back to the container, so we can notice changes to the version. // We own a ref on this. @@ -88,7 +91,7 @@ struct MapIterator { // as this iterator does. This is solely for the benefit of the MapIterator // destructor -- we should never actually access the iterator in this state // except to delete it. - shared_ptr<Message> owner; + CMessage::OwnerRef owner; // The version of the map when we took the iterator to it. // @@ -324,6 +327,33 @@ PyObject* Clear(PyObject* _self) { Py_RETURN_NONE; } +PyObject* GetEntryClass(PyObject* _self) { + MapContainer* self = GetMap(_self); + CMessageClass* message_class = message_factory::GetMessageClass( + cmessage::GetFactoryForMessage(self->parent), + self->parent_field_descriptor->message_type()); + Py_XINCREF(message_class); + return reinterpret_cast<PyObject*>(message_class); +} + +PyObject* MergeFrom(PyObject* _self, PyObject* arg) { + MapContainer* self = GetMap(_self); + MapContainer* other_map = GetMap(arg); + Message* message = self->GetMutableMessage(); + const Message* other_message = other_map->message; + const Reflection* reflection = message->GetReflection(); + const Reflection* other_reflection = other_message->GetReflection(); + int count = other_reflection->FieldSize( + *other_message, other_map->parent_field_descriptor); + for (int i = 0 ; i < count; i ++) { + reflection->AddMessage(message, self->parent_field_descriptor)->MergeFrom( + other_reflection->GetRepeatedMessage( + *other_message, other_map->parent_field_descriptor, i)); + } + self->version++; + Py_RETURN_NONE; +} + PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) { MapContainer* self = GetMap(_self); @@ -344,9 +374,10 @@ PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) { } // Initializes the underlying Message object of "to" so it becomes a new parent -// repeated scalar, and copies all the values from "from" to it. A child scalar +// map container, and copies all the values from "from" to it. A child map // container can be released by passing it as both from and to (e.g. making it // the recipient of the new parent message and copying the values from itself). +// In fact, this is the only supported use at the moment. static int InitializeAndCopyToParentContainer(MapContainer* from, MapContainer* to) { // For now we require from == to, re-evaluate if we want to support deep copy @@ -358,7 +389,7 @@ static int InitializeAndCopyToParentContainer(MapContainer* from, // A somewhat roundabout way of copying just one field from old_message to // new_message. This is the best we can do with what Reflection gives us. Message* mutable_old = from->GetMutableMessage(); - vector<const FieldDescriptor*> fields; + std::vector<const FieldDescriptor*> fields; fields.push_back(from->parent_field_descriptor); // Move the map field into the new message. @@ -395,12 +426,7 @@ PyObject *NewScalarMapContainer( return NULL; } -#if PY_MAJOR_VERSION >= 3 - ScopedPyObjectPtr obj(PyType_GenericAlloc( - reinterpret_cast<PyTypeObject *>(ScalarMapContainer_Type), 0)); -#else - ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0)); -#endif + ScopedPyObjectPtr obj(PyType_GenericAlloc(ScalarMapContainer_Type, 0)); if (obj.get() == NULL) { return PyErr_Format(PyExc_RuntimeError, "Could not allocate new container."); @@ -522,6 +548,10 @@ static PyMethodDef ScalarMapMethods[] = { "Removes all elements from the map." }, { "get", ScalarMapGet, METH_VARARGS, "Gets the value for the given key if present, or otherwise a default" }, + { "GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS, + "Return the class used to build Entries of (key, value) pairs." }, + { "MergeFrom", (PyCFunction)MergeFrom, METH_O, + "Merges a map into the current map." }, /* { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, "Makes a deep copy of the class." }, @@ -531,6 +561,7 @@ static PyMethodDef ScalarMapMethods[] = { {NULL, NULL}, }; +PyTypeObject *ScalarMapContainer_Type; #if PY_MAJOR_VERSION >= 3 static PyType_Slot ScalarMapContainer_Type_slots[] = { {Py_tp_dealloc, (void *)ScalarMapDealloc}, @@ -549,7 +580,6 @@ static PyMethodDef ScalarMapMethods[] = { Py_TPFLAGS_DEFAULT, ScalarMapContainer_Type_slots }; - PyObject *ScalarMapContainer_Type; #else static PyMappingMethods ScalarMapMappingMethods = { MapReflectionFriend::Length, // mp_length @@ -557,7 +587,7 @@ static PyMethodDef ScalarMapMethods[] = { MapReflectionFriend::ScalarMapSetItem, // mp_ass_subscript }; - PyTypeObject ScalarMapContainer_Type = { + PyTypeObject _ScalarMapContainer_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME ".ScalarMapContainer", // tp_name sizeof(MapContainer), // tp_basicsize @@ -610,8 +640,7 @@ static PyObject* GetCMessage(MessageMapContainer* self, Message* message) { PyObject* ret = PyDict_GetItem(self->message_dict, key.get()); if (ret == NULL) { - CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init, - message->GetDescriptor()); + CMessage* cmsg = cmessage::NewEmptyMessage(self->message_class); ret = reinterpret_cast<PyObject*>(cmsg); if (cmsg == NULL) { @@ -634,17 +663,12 @@ static PyObject* GetCMessage(MessageMapContainer* self, Message* message) { PyObject* NewMessageMapContainer( CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor, - PyObject* concrete_class) { + CMessageClass* message_class) { if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { return NULL; } -#if PY_MAJOR_VERSION >= 3 - PyObject* obj = PyType_GenericAlloc( - reinterpret_cast<PyTypeObject *>(MessageMapContainer_Type), 0); -#else - PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0); -#endif + PyObject* obj = PyType_GenericAlloc(MessageMapContainer_Type, 0); if (obj == NULL) { return PyErr_Format(PyExc_RuntimeError, "Could not allocate new container."); @@ -669,8 +693,8 @@ PyObject* NewMessageMapContainer( "Could not allocate message dict."); } - Py_INCREF(concrete_class); - self->subclass_init = concrete_class; + Py_INCREF(message_class); + self->message_class = message_class; if (self->key_field_descriptor == NULL || self->value_field_descriptor == NULL) { @@ -705,8 +729,33 @@ int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key, } // Delete key from map. - if (reflection->DeleteMapValue(message, self->parent_field_descriptor, + if (reflection->ContainsMapKey(*message, self->parent_field_descriptor, map_key)) { + // Delete key from CMessage dict. + MapValueRef value; + reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor, + map_key, &value); + ScopedPyObjectPtr key(PyLong_FromVoidPtr(value.MutableMessageValue())); + + PyObject* cmsg_value = PyDict_GetItem(self->message_dict, key.get()); + if (cmsg_value) { + // Need to keep CMessage stay alive if it is still referenced after + // deletion. Makes a new message and swaps values into CMessage + // instead of just removing. + CMessage* cmsg = reinterpret_cast<CMessage*>(cmsg_value); + Message* msg = cmsg->message; + cmsg->owner.reset(msg->New()); + cmsg->message = cmsg->owner.get(); + cmsg->parent = NULL; + msg->GetReflection()->Swap(msg, cmsg->message); + if (PyDict_DelItem(self->message_dict, key.get()) < 0) { + return -1; + } + } + + // Delete key from map. + reflection->DeleteMapValue(message, self->parent_field_descriptor, + map_key); return 0; } else { PyErr_Format(PyExc_KeyError, "Key not present in map"); @@ -763,6 +812,7 @@ static void MessageMapDealloc(PyObject* _self) { MessageMapContainer* self = GetMessageMap(_self); self->owner.reset(); Py_DECREF(self->message_dict); + Py_DECREF(self->message_class); Py_TYPE(_self)->tp_free(_self); } @@ -775,6 +825,10 @@ static PyMethodDef MessageMapMethods[] = { "Gets the value for the given key if present, or otherwise a default" }, { "get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O, "Alias for getitem, useful to make explicit that the map is mutated." }, + { "GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS, + "Return the class used to build Entries of (key, value) pairs." }, + { "MergeFrom", (PyCFunction)MergeFrom, METH_O, + "Merges a map into the current map." }, /* { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, "Makes a deep copy of the class." }, @@ -784,6 +838,7 @@ static PyMethodDef MessageMapMethods[] = { {NULL, NULL}, }; +PyTypeObject *MessageMapContainer_Type; #if PY_MAJOR_VERSION >= 3 static PyType_Slot MessageMapContainer_Type_slots[] = { {Py_tp_dealloc, (void *)MessageMapDealloc}, @@ -802,8 +857,6 @@ static PyMethodDef MessageMapMethods[] = { Py_TPFLAGS_DEFAULT, MessageMapContainer_Type_slots }; - - PyObject *MessageMapContainer_Type; #else static PyMappingMethods MessageMapMappingMethods = { MapReflectionFriend::Length, // mp_length @@ -811,7 +864,7 @@ static PyMethodDef MessageMapMethods[] = { MapReflectionFriend::MessageMapSetItem, // mp_ass_subscript }; - PyTypeObject MessageMapContainer_Type = { + PyTypeObject _MessageMapContainer_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME ".MessageMapContainer", // tp_name sizeof(MessageMapContainer), // tp_basicsize @@ -960,6 +1013,63 @@ PyTypeObject MapIterator_Type = { 0, // tp_init }; +bool InitMapContainers() { + // ScalarMapContainer_Type derives from our MutableMapping type. + ScopedPyObjectPtr containers(PyImport_ImportModule( + "google.protobuf.internal.containers")); + if (containers == NULL) { + return false; + } + + ScopedPyObjectPtr mutable_mapping( + PyObject_GetAttrString(containers.get(), "MutableMapping")); + if (mutable_mapping == NULL) { + return false; + } + + if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) { + return false; + } + + Py_INCREF(mutable_mapping.get()); +#if PY_MAJOR_VERSION >= 3 + PyObject* bases = PyTuple_New(1); + PyTuple_SET_ITEM(bases, 0, mutable_mapping.get()); + + ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>( + PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases)); +#else + _ScalarMapContainer_Type.tp_base = + reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); + + if (PyType_Ready(&_ScalarMapContainer_Type) < 0) { + return false; + } + + ScalarMapContainer_Type = &_ScalarMapContainer_Type; +#endif + + if (PyType_Ready(&MapIterator_Type) < 0) { + return false; + } + +#if PY_MAJOR_VERSION >= 3 + MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>( + PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases)); +#else + Py_INCREF(mutable_mapping.get()); + _MessageMapContainer_Type.tp_base = + reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); + + if (PyType_Ready(&_MessageMapContainer_Type) < 0) { + return false; + } + + MessageMapContainer_Type = &_MessageMapContainer_Type; +#endif + return true; +} + } // namespace python } // namespace protobuf } // namespace google |