aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/lib
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-03-30 14:56:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-30 14:58:45 -0700
commit97731cb122f53552bd15351e046a256f78cca444 (patch)
treedbec22e40bc41f690953b7aad5498516cf457908 /tensorflow/python/lib
parent15c10899c9c0e1717251b380330cc248b2c76c9c (diff)
Raise exception in SWIG on bad TF_Status from C API.
This change provides an alternative mechanism to tf.raise_exception_on_not_ok_status(), which is inefficient and error-prone (people often use the status multiple times in the with block, but it's only checked when the context manager exits). Instead, it uses SWIG to automatically raise an exception when a C API method fails. Note that this removes the status argument from affected methods. For now, I've only applied this typemap to C API methods. It would be good to expand this to all uses of raise_exception_on_not_ok_status. PiperOrigin-RevId: 191121016
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r--tensorflow/python/lib/core/py_exception_registry.cc50
-rw-r--r--tensorflow/python/lib/core/py_exception_registry.h73
-rw-r--r--tensorflow/python/lib/core/py_exception_registry.i28
3 files changed, 151 insertions, 0 deletions
diff --git a/tensorflow/python/lib/core/py_exception_registry.cc b/tensorflow/python/lib/core/py_exception_registry.cc
new file mode 100644
index 0000000000..6637de632b
--- /dev/null
+++ b/tensorflow/python/lib/core/py_exception_registry.cc
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/python/lib/core/py_exception_registry.h"
+
+#include <Python.h>
+
+namespace tensorflow {
+
+PyExceptionRegistry* PyExceptionRegistry::singleton_ = nullptr;
+
+void PyExceptionRegistry::Init(PyObject* code_to_exc_type_map) {
+ DCHECK(singleton_ == nullptr) << "PyExceptionRegistry::Init() already called";
+ singleton_ = new PyExceptionRegistry;
+
+ DCHECK(PyDict_Check(code_to_exc_type_map));
+ PyObject* key;
+ PyObject* value;
+ Py_ssize_t pos = 0;
+ while (PyDict_Next(code_to_exc_type_map, &pos, &key, &value)) {
+ TF_Code code = static_cast<TF_Code>(PyLong_AsLong(key));
+ singleton_->exc_types_[code] = value;
+ // The exception classes should also have the lifetime of the process, but
+ // incref just in case.
+ Py_INCREF(value);
+ }
+}
+
+PyObject* PyExceptionRegistry::Lookup(TF_Code code) {
+ DCHECK(singleton_ != nullptr) << "Must call PyExceptionRegistry::Init() "
+ "before PyExceptionRegistry::Lookup()";
+ DCHECK_NE(code, TF_OK);
+ DCHECK(singleton_->exc_types_.find(code) != singleton_->exc_types_.end())
+ << "Unknown error code passed to PyExceptionRegistry::Lookup: " << code;
+ return singleton_->exc_types_[code];
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/python/lib/core/py_exception_registry.h b/tensorflow/python/lib/core/py_exception_registry.h
new file mode 100644
index 0000000000..2b0f23b548
--- /dev/null
+++ b/tensorflow/python/lib/core/py_exception_registry.h
@@ -0,0 +1,73 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
+#define TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
+
+#include <map>
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/platform/logging.h"
+
+#ifndef PyObject_HEAD
+struct _object;
+typedef _object PyObject;
+#endif
+
+namespace tensorflow {
+
+// Global registry mapping C API error codes to the corresponding custom Python
+// exception type. This is used to expose the exception types to C extension
+// code (i.e. so we can raise custom exceptions via SWIG).
+//
+// Init() must be called exactly once at the beginning of the process before
+// Lookup() can be used.
+//
+// Example usage:
+// TF_Status* status = TF_NewStatus();
+// TF_Foo(..., status);
+//
+// if (TF_GetCode(status) != TF_OK) {
+// PyObject* exc_type = PyExceptionRegistry::Lookup(TF_GetCode(status));
+// // Arguments to OpError base class. Set `node_def` and `op` to None.
+// PyObject* args =
+// Py_BuildValue("sss", nullptr, nullptr, TF_Message(status));
+// PyErr_SetObject(exc_type, args);
+// Py_DECREF(args);
+// TF_DeleteStatus(status);
+// return NULL;
+// }
+class PyExceptionRegistry {
+ public:
+ // Initializes the process-wide registry. Should be called exactly once near
+ // the beginning of the process. The arguments are the various Python
+ // exception types (e.g. `cancelled_exc` corresponds to
+ // errors.CancelledError).
+ static void Init(PyObject* code_to_exc_type_map);
+
+ // Returns the Python exception type corresponding to `code`. Init() must be
+ // called before using this function. `code` should not be TF_OK.
+ static PyObject* Lookup(TF_Code code);
+
+ private:
+ static PyExceptionRegistry* singleton_;
+ PyExceptionRegistry() = default;
+
+ // Maps error codes to the corresponding Python exception type.
+ std::map<TF_Code, PyObject*> exc_types_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
diff --git a/tensorflow/python/lib/core/py_exception_registry.i b/tensorflow/python/lib/core/py_exception_registry.i
new file mode 100644
index 0000000000..e872b74985
--- /dev/null
+++ b/tensorflow/python/lib/core/py_exception_registry.i
@@ -0,0 +1,28 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%include "tensorflow/python/platform/base.i"
+
+%{
+#include "tensorflow/python/lib/core/py_exception_registry.h"
+%}
+
+%ignoreall
+
+%unignore tensorflow::PyExceptionRegistry;
+%unignore tensorflow::PyExceptionRegistry::Init;
+
+%include "tensorflow/python/lib/core/py_exception_registry.h"
+%unignoreall