From 5a76e633ea9b5adb215e93fdc11e1c0c08b3fc74 Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Thu, 17 Nov 2016 16:48:38 -0800 Subject: Integrated internal changes from Google --- python/google/protobuf/pyext/descriptor_pool.cc | 67 +++++++++++++++++++++++++ 1 file changed, 67 insertions(+) (limited to 'python/google/protobuf/pyext/descriptor_pool.cc') diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc index a42e5431..fa66bf9a 100644 --- a/python/google/protobuf/pyext/descriptor_pool.cc +++ b/python/google/protobuf/pyext/descriptor_pool.cc @@ -319,6 +319,51 @@ PyObject* FindFileContainingSymbol(PyDescriptorPool* self, PyObject* arg) { return PyFileDescriptor_FromDescriptor(file_descriptor); } +PyObject* FindExtensionByNumber(PyDescriptorPool* self, PyObject* args) { + PyObject* message_descriptor; + int number; + if (!PyArg_ParseTuple(args, "Oi", &message_descriptor, &number)) { + return NULL; + } + const Descriptor* descriptor = PyMessageDescriptor_AsDescriptor( + message_descriptor); + if (descriptor == NULL) { + return NULL; + } + + const FieldDescriptor* extension_descriptor = + self->pool->FindExtensionByNumber(descriptor, number); + if (extension_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find extension %d", number); + return NULL; + } + + return PyFieldDescriptor_FromDescriptor(extension_descriptor); +} + +PyObject* FindAllExtensions(PyDescriptorPool* self, PyObject* arg) { + const Descriptor* descriptor = PyMessageDescriptor_AsDescriptor(arg); + if (descriptor == NULL) { + return NULL; + } + + std::vector extensions; + self->pool->FindAllExtensions(descriptor, &extensions); + + ScopedPyObjectPtr result(PyList_New(extensions.size())); + if (result == NULL) { + return NULL; + } + for (int i = 0; i < extensions.size(); i++) { + PyObject* extension = PyFieldDescriptor_FromDescriptor(extensions[i]); + if (extension == NULL) { + return NULL; + } + PyList_SET_ITEM(result.get(), i, extension); // Steals the reference. + } + return result.release(); +} + // These functions should not exist -- the only valid way to create // descriptors is to call Add() or AddSerializedFile(). // But these AddDescriptor() functions were created in Python and some people @@ -376,6 +421,22 @@ PyObject* AddEnumDescriptor(PyDescriptorPool* self, PyObject* descriptor) { Py_RETURN_NONE; } +PyObject* AddExtensionDescriptor(PyDescriptorPool* self, PyObject* descriptor) { + const FieldDescriptor* extension_descriptor = + PyFieldDescriptor_AsDescriptor(descriptor); + if (!extension_descriptor) { + return NULL; + } + if (extension_descriptor != + self->pool->FindExtensionByName(extension_descriptor->full_name())) { + PyErr_Format(PyExc_ValueError, + "The extension descriptor %s does not belong to this pool", + extension_descriptor->full_name().c_str()); + return NULL; + } + Py_RETURN_NONE; +} + // The code below loads new Descriptors from a serialized FileDescriptorProto. @@ -475,6 +536,8 @@ static PyMethodDef Methods[] = { "No-op. Add() must have been called before." }, { "AddEnumDescriptor", (PyCFunction)AddEnumDescriptor, METH_O, "No-op. Add() must have been called before." }, + { "AddExtensionDescriptor", (PyCFunction)AddExtensionDescriptor, METH_O, + "No-op. Add() must have been called before." }, { "FindFileByName", (PyCFunction)FindFileByName, METH_O, "Searches for a file descriptor by its .proto name." }, @@ -495,6 +558,10 @@ static PyMethodDef Methods[] = { { "FindFileContainingSymbol", (PyCFunction)FindFileContainingSymbol, METH_O, "Gets the FileDescriptor containing the specified symbol." }, + { "FindExtensionByNumber", (PyCFunction)FindExtensionByNumber, METH_VARARGS, + "Gets the extension descriptor for the given number." }, + { "FindAllExtensions", (PyCFunction)FindAllExtensions, METH_O, + "Gets all known extensions of the given message descriptor." }, {NULL} }; -- cgit v1.2.3